Updated for better compatibility with the Shape dialect
Used a tensor<index> to tensor<ixx> cast instead of manually constructing. This increases support for rankless tensors. PiperOrigin-RevId: 316173553 Change-Id: I10b2b6f94c56eca7b2827b772e9a4532fc7530dd
This commit is contained in:
parent
694b502173
commit
3a2d113417
|
@ -1545,61 +1545,33 @@ func @rfft_1D(%arg0: tensor<8xf32>) -> tensor<8xcomplex<f32>> {
|
|||
// CHECK-LABEL: func @shape_1D
|
||||
func @shape_1D(%arg0: tensor<?xf32>) -> tensor<1xi32> {
|
||||
// CHECK: [[SHAPE:%.+]] = shape.shape_of %arg0
|
||||
// CHECK: %[[C0:.*]] = shape.const_size 0
|
||||
// CHECK-DAG: [[EXTENT:%.+]] = shape.get_extent [[SHAPE]], %[[C0]]
|
||||
// CHECK-DAG: [[TO_INDEX:%.+]] = shape.size_to_index [[EXTENT]]
|
||||
// CHECK-DAG: [[CAST:%.+]] = index_cast [[TO_INDEX]]
|
||||
// CHECK-DAG: [[TENSOR:%.+]] = tensor_from_elements([[CAST]])
|
||||
// CHECK-DAG: [[RESHAPE:%.+]] = "xla_hlo.reshape"([[TENSOR]])
|
||||
// CHECK-DAG: [[CONCAT:%.+]] = "xla_hlo.concatenate"([[RESHAPE]]) {dimension = 0 : i64}
|
||||
// CHECK: [[TENSOR:%.+]] = "shape.to_extent_tensor"([[SHAPE]])
|
||||
// CHECK: [[CAST:%.+]] = index_cast [[TENSOR]]
|
||||
%0 = "tf.Shape"(%arg0) : (tensor<?xf32>) -> tensor<1xi32>
|
||||
|
||||
// CHECK: return [[CONCAT]]
|
||||
// CHECK: return [[CAST]]
|
||||
return %0 : tensor<1xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @shape_2D
|
||||
func @shape_2D(%arg0: tensor<?x?xf32>) -> tensor<2xi32> {
|
||||
// CHECK: [[SHAPE:%.+]] = shape.shape_of %arg0
|
||||
// CHECK: %[[C0:.*]] = shape.const_size 0
|
||||
// CHECK: [[EXTENT0:%.+]] = shape.get_extent [[SHAPE]], %[[C0]]
|
||||
// CHECK-DAG: [[TO_INDEX0:%.+]] = shape.size_to_index [[EXTENT0]]
|
||||
// CHECK-DAG: [[CAST0:%.+]] = index_cast [[TO_INDEX0]]
|
||||
// CHECK-DAG: [[TENSOR0:%.+]] = tensor_from_elements([[CAST0]])
|
||||
// CHECK-DAG: [[RESHAPE0:%.+]] = "xla_hlo.reshape"([[TENSOR0]])
|
||||
// CHECK: %[[C1:.*]] = shape.const_size 1
|
||||
// CHECK-DAG: [[EXTENT1:%.+]] = shape.get_extent [[SHAPE]], %[[C1]]
|
||||
// CHECK-DAG: [[TO_INDEX1:%.+]] = shape.size_to_index [[EXTENT1]]
|
||||
// CHECK-DAG: [[CAST1:%.+]] = index_cast [[TO_INDEX1]]
|
||||
// CHECK-DAG: [[TENSOR1:%.+]] = tensor_from_elements([[CAST1]])
|
||||
// CHECK-DAG: [[RESHAPE1:%.+]] = "xla_hlo.reshape"([[TENSOR1]])
|
||||
// CHECK-DAG: [[CONCAT:%.+]] = "xla_hlo.concatenate"([[RESHAPE0]], [[RESHAPE1]]) {dimension = 0 : i64}
|
||||
// CHECK: [[TENSOR:%.+]] = "shape.to_extent_tensor"([[SHAPE]])
|
||||
// CHECK: [[CAST:%.+]] = index_cast [[TENSOR]]
|
||||
%0 = "tf.Shape"(%arg0) : (tensor<?x?xf32>) -> tensor<2xi32>
|
||||
|
||||
// CHECK: return [[CONCAT]]
|
||||
return %0 : tensor<2xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @shape_with_const
|
||||
func @shape_with_const(%arg0: tensor<?x3xf32>) -> tensor<2xi32> {
|
||||
// CHECK: [[SHAPE:%.+]] = shape.shape_of %arg0
|
||||
// CHECK: %[[C0:.*]] = shape.const_size 0
|
||||
// CHECK-DAG: [[EXTENT:%.+]] = shape.get_extent [[SHAPE]], %[[C0]]
|
||||
// CHECK-DAG: [[TO_INDEX:%.+]] = shape.size_to_index [[EXTENT]]
|
||||
// CHECK-DAG: [[CAST:%.+]] = index_cast [[TO_INDEX]]
|
||||
// CHECK-DAG: [[TENSOR:%.+]] = tensor_from_elements([[CAST]])
|
||||
// CHECK-DAG: [[RESHAPE:%.+]] = "xla_hlo.reshape"([[TENSOR]])
|
||||
// CHECK-DAG: [[CONST:%.+]] = xla_hlo.constant dense<3>
|
||||
// CHECK-DAG: [[CONCAT:%.+]] = "xla_hlo.concatenate"([[RESHAPE]], [[CONST]]) {dimension = 0 : i64}
|
||||
%0 = "tf.Shape"(%arg0) : (tensor<?x3xf32>) -> tensor<2xi32>
|
||||
|
||||
// CHECK: return [[CONCAT]]
|
||||
// CHECK: return [[CAST]]
|
||||
return %0 : tensor<2xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @shape_rankless
|
||||
func @shape_rankless(%arg0: tensor<*xf32>) -> tensor<?xi32> {
|
||||
// CHECK: [[SHAPE:%.+]] = shape.shape_of %arg0
|
||||
// CHECK: [[TENSOR:%.+]] = "shape.to_extent_tensor"([[SHAPE]])
|
||||
// CHECK: [[CAST:%.+]] = index_cast [[TENSOR]]
|
||||
%0 = "tf.Shape"(%arg0) : (tensor<*xf32>) -> tensor<?xi32>
|
||||
|
||||
// CHECK: return [[CAST]]
|
||||
return %0 : tensor<?xi32>
|
||||
}
|
||||
|
||||
|
|
|
@ -4610,45 +4610,16 @@ class ConvertShapeOp : public OpRewritePattern<TF::ShapeOp> {
|
|||
LogicalResult matchAndRewrite(TF::ShapeOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value input = op.input();
|
||||
auto input_ty = input.getType().dyn_cast<RankedTensorType>();
|
||||
// If the shape is static it can be canonicalized.
|
||||
if (!input_ty || input_ty.hasStaticShape()) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto result_ty = op.getResult().getType().cast<RankedTensorType>();
|
||||
auto element_ty = result_ty.getElementType();
|
||||
|
||||
int64_t rank = input_ty.getRank();
|
||||
auto shape_op = rewriter.create<shape::ShapeOfOp>(op.getLoc(), input);
|
||||
auto result_ty = op.getResult().getType().cast<RankedTensorType>();
|
||||
|
||||
auto index_ty = RankedTensorType::get({1}, element_ty);
|
||||
llvm::SmallVector<Value, 4> dim_values;
|
||||
for (int64_t i = 0; i < rank; ++i) {
|
||||
if (!input_ty.isDynamicDim(i)) {
|
||||
auto dim_attr = DenseElementsAttr::get(
|
||||
index_ty,
|
||||
rewriter.getIntegerAttr(element_ty, input_ty.getDimSize(i)));
|
||||
auto index = rewriter.create<xla_hlo::ConstOp>(op.getLoc(), dim_attr);
|
||||
dim_values.push_back(index);
|
||||
continue;
|
||||
}
|
||||
auto index_tensor =
|
||||
RankedTensorType::get(result_ty.getShape(), rewriter.getIndexType());
|
||||
auto extent_tensor = rewriter.create<shape::ToExtentTensorOp>(
|
||||
op.getLoc(), index_tensor, shape_op);
|
||||
|
||||
auto extent_op =
|
||||
rewriter.create<shape::GetExtentOp>(op.getLoc(), shape_op, i);
|
||||
auto index_op = rewriter.create<shape::SizeToIndexOp>(
|
||||
op.getLoc(), rewriter.getIndexType(), extent_op);
|
||||
auto int_op =
|
||||
rewriter.create<IndexCastOp>(op.getLoc(), element_ty, index_op);
|
||||
auto from_tensor = rewriter.create<TensorFromElementsOp>(
|
||||
op.getLoc(), int_op.getResult());
|
||||
auto reshape_op =
|
||||
rewriter.create<ReshapeOp>(op.getLoc(), index_ty, from_tensor);
|
||||
dim_values.push_back(reshape_op);
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<ConcatenateOp>(op, result_ty, dim_values,
|
||||
rewriter.getI64IntegerAttr(0));
|
||||
rewriter.replaceOpWithNewOp<IndexCastOp>(op, result_ty, extent_tensor);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue