Lowering ShapeOp to Shape Dialect for Ranked Tensors
The lowering extracts the ShapeType and performs a per element conversion to the correct element type. If the op is input is dynamic but has constant length dimensions, those dimensions are instead created as constants. PiperOrigin-RevId: 314768753 Change-Id: Iede948fea424aedb05f8e37883e8fe1041b3fb87
This commit is contained in:
parent
eb5226a3e0
commit
991d76e6ac
@ -1526,6 +1526,67 @@ func @rfft_1D(%arg0: tensor<8xf32>) -> tensor<8xcomplex<f32>> {
|
||||
return %0 : tensor<8xcomplex<f32>>
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Shape op legalization.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK-LABEL: func @shape_1D
|
||||
func @shape_1D(%arg0: tensor<?xf32>) -> tensor<1xi32> {
|
||||
// CHECK-DAG: [[SHAPE:%.+]] = shape.shape_of %arg0
|
||||
// CHECK-DAG: [[EXTENT:%.+]] = shape.get_extent [[SHAPE]], 0
|
||||
// CHECK-DAG: [[TO_INDEX:%.+]] = shape.size_to_index [[EXTENT]]
|
||||
// CHECK-DAG: [[CAST:%.+]] = index_cast [[TO_INDEX]]
|
||||
// CHECK-DAG: [[TENSOR:%.+]] = tensor_from_elements([[CAST]])
|
||||
// CHECK-DAG: [[RESHAPE:%.+]] = "xla_hlo.reshape"([[TENSOR]])
|
||||
// CHECK-DAG: [[CONCAT:%.+]] = "xla_hlo.concatenate"([[RESHAPE]]) {dimension = 0 : i64}
|
||||
%0 = "tf.Shape"(%arg0) : (tensor<?xf32>) -> tensor<1xi32>
|
||||
|
||||
// CHECK: return [[CONCAT]]
|
||||
return %0 : tensor<1xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @shape_2D
|
||||
func @shape_2D(%arg0: tensor<?x?xf32>) -> tensor<2xi32> {
|
||||
// CHECK-DAG: [[SHAPE:%.+]] = shape.shape_of %arg0
|
||||
// CHECK-DAG: [[EXTENT0:%.+]] = shape.get_extent [[SHAPE]], 0
|
||||
// CHECK-DAG: [[EXTENT1:%.+]] = shape.get_extent [[SHAPE]], 1
|
||||
// CHECK-DAG: [[TO_INDEX0:%.+]] = shape.size_to_index [[EXTENT0]]
|
||||
// CHECK-DAG: [[TO_INDEX1:%.+]] = shape.size_to_index [[EXTENT1]]
|
||||
// CHECK-DAG: [[CAST0:%.+]] = index_cast [[TO_INDEX0]]
|
||||
// CHECK-DAG: [[CAST1:%.+]] = index_cast [[TO_INDEX1]]
|
||||
// CHECK-DAG: [[TENSOR0:%.+]] = tensor_from_elements([[CAST0]])
|
||||
// CHECK-DAG: [[TENSOR1:%.+]] = tensor_from_elements([[CAST1]])
|
||||
// CHECK-DAG: [[RESHAPE0:%.+]] = "xla_hlo.reshape"([[TENSOR0]])
|
||||
// CHECK-DAG: [[RESHAPE1:%.+]] = "xla_hlo.reshape"([[TENSOR1]])
|
||||
// CHECK-DAG: [[CONCAT:%.+]] = "xla_hlo.concatenate"([[RESHAPE0]], [[RESHAPE1]]) {dimension = 0 : i64}
|
||||
%0 = "tf.Shape"(%arg0) : (tensor<?x?xf32>) -> tensor<2xi32>
|
||||
|
||||
// CHECK: return [[CONCAT]]
|
||||
return %0 : tensor<2xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @shape_with_const
|
||||
func @shape_with_const(%arg0: tensor<?x3xf32>) -> tensor<2xi32> {
|
||||
// CHECK-DAG: [[SHAPE:%.+]] = shape.shape_of %arg0
|
||||
// CHECK-DAG: [[EXTENT:%.+]] = shape.get_extent [[SHAPE]], 0
|
||||
// CHECK-DAG: [[TO_INDEX:%.+]] = shape.size_to_index [[EXTENT]]
|
||||
// CHECK-DAG: [[CAST:%.+]] = index_cast [[TO_INDEX]]
|
||||
// CHECK-DAG: [[TENSOR:%.+]] = tensor_from_elements([[CAST]])
|
||||
// CHECK-DAG: [[RESHAPE:%.+]] = "xla_hlo.reshape"([[TENSOR]])
|
||||
// CHECK-DAG: [[CONST:%.+]] = xla_hlo.constant dense<3>
|
||||
// CHECK-DAG: [[CONCAT:%.+]] = "xla_hlo.concatenate"([[RESHAPE]], [[CONST]]) {dimension = 0 : i64}
|
||||
%0 = "tf.Shape"(%arg0) : (tensor<?x3xf32>) -> tensor<2xi32>
|
||||
|
||||
// CHECK: return [[CONCAT]]
|
||||
return %0 : tensor<2xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @shape_rankless
|
||||
func @shape_rankless(%arg0: tensor<*xf32>) -> tensor<?xi32> {
|
||||
%0 = "tf.Shape"(%arg0) : (tensor<*xf32>) -> tensor<?xi32>
|
||||
return %0 : tensor<?xi32>
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Transpose op legalization.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -4541,6 +4541,60 @@ class ConvertCumsumOp : public OpRewritePattern<TF::CumsumOp> {
|
||||
}
|
||||
};
|
||||
|
||||
// Converts the Tensorflow ShapeOp to a sequence of Shape dialect and Standard
|
||||
// dialect lowerings. This involves extracting the shape type, extracting and
|
||||
// converting each dimension to a known integer type, and repacking into a final
|
||||
// tensor.
|
||||
class ConvertShapeOp : public OpRewritePattern<TF::ShapeOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(TF::ShapeOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value input = op.input();
|
||||
auto input_ty = input.getType().dyn_cast<RankedTensorType>();
|
||||
// If the shape is static it can be canonicalized.
|
||||
if (!input_ty || input_ty.hasStaticShape()) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto result_ty = op.getResult().getType().cast<RankedTensorType>();
|
||||
auto element_ty = result_ty.getElementType();
|
||||
|
||||
int64_t rank = input_ty.getRank();
|
||||
auto shape_op = rewriter.create<shape::ShapeOfOp>(op.getLoc(), input);
|
||||
|
||||
auto index_ty = RankedTensorType::get({1}, element_ty);
|
||||
llvm::SmallVector<Value, 4> dim_values;
|
||||
for (int64_t i = 0; i < rank; ++i) {
|
||||
if (!input_ty.isDynamicDim(i)) {
|
||||
auto dim_attr = DenseElementsAttr::get(
|
||||
index_ty,
|
||||
rewriter.getIntegerAttr(element_ty, input_ty.getDimSize(i)));
|
||||
auto index = rewriter.create<xla_hlo::ConstOp>(op.getLoc(), dim_attr);
|
||||
dim_values.push_back(index);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto extent_op = rewriter.create<shape::GetExtentOp>(
|
||||
op.getLoc(), shape_op, rewriter.getI64IntegerAttr(i));
|
||||
auto index_op = rewriter.create<shape::SizeToIndexOp>(
|
||||
op.getLoc(), rewriter.getIndexType(), extent_op);
|
||||
auto int_op =
|
||||
rewriter.create<IndexCastOp>(op.getLoc(), element_ty, index_op);
|
||||
auto from_tensor = rewriter.create<TensorFromElementsOp>(
|
||||
op.getLoc(), int_op.getResult());
|
||||
auto reshape_op =
|
||||
rewriter.create<ReshapeOp>(op.getLoc(), index_ty, from_tensor);
|
||||
dim_values.push_back(reshape_op);
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<ConcatenateOp>(op, result_ty, dim_values,
|
||||
rewriter.getI64IntegerAttr(0));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Converts a TF QR op to HLO.
|
||||
class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
|
||||
public:
|
||||
@ -5120,7 +5174,8 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion,
|
||||
ConvertMaxPool3DOp, ConvertMaxPool2DGradOp, ConvertMaxPool3DGradOp,
|
||||
ConvertMeanOp, ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp,
|
||||
ConvertProdOp, ConvertQrOp, ConvertRangeOp, ConvertSelectV2Op,
|
||||
ConvertSigmoidOp, ConvertSizeOp, ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
|
||||
ConvertShapeOp, ConvertSigmoidOp, ConvertSizeOp,
|
||||
ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
|
||||
ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSplitOp, ConvertSplitVOp,
|
||||
ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp,
|
||||
ConvertTensorScatterUpdateOp, ConvertTileOp, ConvertTopKV2Op,
|
||||
|
Loading…
x
Reference in New Issue
Block a user