Lowering ShapeOp to Shape Dialect for Ranked Tensors

The lowering extracts the shape as a ShapeType value and performs a per-element
conversion to the correct element type. If the op's input is dynamic but has
constant-length dimensions, those dimensions are created as constants instead.
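As a rough sketch of what the pattern emits (using a tensor<?x3xf32> input to
mirror the @shape_with_const test added below; value names are illustrative and
the ops and types follow the test's CHECK lines rather than a verified build):

  %shape    = shape.shape_of %arg0
  %extent   = shape.get_extent %shape, 0
  %index    = shape.size_to_index %extent
  %dim0_i32 = index_cast %index : index to i32
  %elems    = tensor_from_elements(%dim0_i32)
  %dim0     = "xla_hlo.reshape"(%elems)
  %dim1     = xla_hlo.constant dense<3> : tensor<1xi32>
  %result   = "xla_hlo.concatenate"(%dim0, %dim1) {dimension = 0 : i64}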

PiperOrigin-RevId: 314768753
Change-Id: Iede948fea424aedb05f8e37883e8fe1041b3fb87
Robert Suderman 2020-06-04 11:28:10 -07:00 committed by TensorFlower Gardener
parent eb5226a3e0
commit 991d76e6ac
2 changed files with 117 additions and 1 deletion


@@ -1526,6 +1526,67 @@ func @rfft_1D(%arg0: tensor<8xf32>) -> tensor<8xcomplex<f32>> {
return %0 : tensor<8xcomplex<f32>>
}
//===----------------------------------------------------------------------===//
// Shape op legalization.
//===----------------------------------------------------------------------===//
// CHECK-LABEL: func @shape_1D
func @shape_1D(%arg0: tensor<?xf32>) -> tensor<1xi32> {
// CHECK-DAG: [[SHAPE:%.+]] = shape.shape_of %arg0
// CHECK-DAG: [[EXTENT:%.+]] = shape.get_extent [[SHAPE]], 0
// CHECK-DAG: [[TO_INDEX:%.+]] = shape.size_to_index [[EXTENT]]
// CHECK-DAG: [[CAST:%.+]] = index_cast [[TO_INDEX]]
// CHECK-DAG: [[TENSOR:%.+]] = tensor_from_elements([[CAST]])
// CHECK-DAG: [[RESHAPE:%.+]] = "xla_hlo.reshape"([[TENSOR]])
// CHECK-DAG: [[CONCAT:%.+]] = "xla_hlo.concatenate"([[RESHAPE]]) {dimension = 0 : i64}
%0 = "tf.Shape"(%arg0) : (tensor<?xf32>) -> tensor<1xi32>
// CHECK: return [[CONCAT]]
return %0 : tensor<1xi32>
}
// CHECK-LABEL: func @shape_2D
func @shape_2D(%arg0: tensor<?x?xf32>) -> tensor<2xi32> {
// CHECK-DAG: [[SHAPE:%.+]] = shape.shape_of %arg0
// CHECK-DAG: [[EXTENT0:%.+]] = shape.get_extent [[SHAPE]], 0
// CHECK-DAG: [[EXTENT1:%.+]] = shape.get_extent [[SHAPE]], 1
// CHECK-DAG: [[TO_INDEX0:%.+]] = shape.size_to_index [[EXTENT0]]
// CHECK-DAG: [[TO_INDEX1:%.+]] = shape.size_to_index [[EXTENT1]]
// CHECK-DAG: [[CAST0:%.+]] = index_cast [[TO_INDEX0]]
// CHECK-DAG: [[CAST1:%.+]] = index_cast [[TO_INDEX1]]
// CHECK-DAG: [[TENSOR0:%.+]] = tensor_from_elements([[CAST0]])
// CHECK-DAG: [[TENSOR1:%.+]] = tensor_from_elements([[CAST1]])
// CHECK-DAG: [[RESHAPE0:%.+]] = "xla_hlo.reshape"([[TENSOR0]])
// CHECK-DAG: [[RESHAPE1:%.+]] = "xla_hlo.reshape"([[TENSOR1]])
// CHECK-DAG: [[CONCAT:%.+]] = "xla_hlo.concatenate"([[RESHAPE0]], [[RESHAPE1]]) {dimension = 0 : i64}
%0 = "tf.Shape"(%arg0) : (tensor<?x?xf32>) -> tensor<2xi32>
// CHECK: return [[CONCAT]]
return %0 : tensor<2xi32>
}
// CHECK-LABEL: func @shape_with_const
func @shape_with_const(%arg0: tensor<?x3xf32>) -> tensor<2xi32> {
// CHECK-DAG: [[SHAPE:%.+]] = shape.shape_of %arg0
// CHECK-DAG: [[EXTENT:%.+]] = shape.get_extent [[SHAPE]], 0
// CHECK-DAG: [[TO_INDEX:%.+]] = shape.size_to_index [[EXTENT]]
// CHECK-DAG: [[CAST:%.+]] = index_cast [[TO_INDEX]]
// CHECK-DAG: [[TENSOR:%.+]] = tensor_from_elements([[CAST]])
// CHECK-DAG: [[RESHAPE:%.+]] = "xla_hlo.reshape"([[TENSOR]])
// CHECK-DAG: [[CONST:%.+]] = xla_hlo.constant dense<3>
// CHECK-DAG: [[CONCAT:%.+]] = "xla_hlo.concatenate"([[RESHAPE]], [[CONST]]) {dimension = 0 : i64}
%0 = "tf.Shape"(%arg0) : (tensor<?x3xf32>) -> tensor<2xi32>
// CHECK: return [[CONCAT]]
return %0 : tensor<2xi32>
}
// CHECK-LABEL: func @shape_rankless
func @shape_rankless(%arg0: tensor<*xf32>) -> tensor<?xi32> {
%0 = "tf.Shape"(%arg0) : (tensor<*xf32>) -> tensor<?xi32>
return %0 : tensor<?xi32>
}
//===----------------------------------------------------------------------===//
// Transpose op legalization.
//===----------------------------------------------------------------------===//


@@ -4541,6 +4541,60 @@ class ConvertCumsumOp : public OpRewritePattern<TF::CumsumOp> {
}
};
// Converts the TensorFlow ShapeOp to a sequence of Shape dialect and Standard
// dialect ops. This involves extracting the shape type, extracting and
// converting each dimension to a known integer type, and repacking them into a
// final tensor.
class ConvertShapeOp : public OpRewritePattern<TF::ShapeOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TF::ShapeOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.input();
    auto input_ty = input.getType().dyn_cast<RankedTensorType>();
    // If the shape is static it can be canonicalized.
    if (!input_ty || input_ty.hasStaticShape()) {
      return failure();
    }

    auto result_ty = op.getResult().getType().cast<RankedTensorType>();
    auto element_ty = result_ty.getElementType();

    int64_t rank = input_ty.getRank();
    auto shape_op = rewriter.create<shape::ShapeOfOp>(op.getLoc(), input);
    auto index_ty = RankedTensorType::get({1}, element_ty);
    llvm::SmallVector<Value, 4> dim_values;
    for (int64_t i = 0; i < rank; ++i) {
      if (!input_ty.isDynamicDim(i)) {
        auto dim_attr = DenseElementsAttr::get(
            index_ty,
            rewriter.getIntegerAttr(element_ty, input_ty.getDimSize(i)));
        auto index = rewriter.create<xla_hlo::ConstOp>(op.getLoc(), dim_attr);
        dim_values.push_back(index);
        continue;
      }

      auto extent_op = rewriter.create<shape::GetExtentOp>(
          op.getLoc(), shape_op, rewriter.getI64IntegerAttr(i));
      auto index_op = rewriter.create<shape::SizeToIndexOp>(
          op.getLoc(), rewriter.getIndexType(), extent_op);
      auto int_op =
          rewriter.create<IndexCastOp>(op.getLoc(), element_ty, index_op);
      auto from_tensor = rewriter.create<TensorFromElementsOp>(
          op.getLoc(), int_op.getResult());
      auto reshape_op =
          rewriter.create<ReshapeOp>(op.getLoc(), index_ty, from_tensor);
      dim_values.push_back(reshape_op);
    }

    rewriter.replaceOpWithNewOp<ConcatenateOp>(op, result_ty, dim_values,
                                               rewriter.getI64IntegerAttr(0));
    return success();
  }
};
// Converts a TF QR op to HLO.
class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
public:
@@ -5120,7 +5174,8 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion,
ConvertMaxPool3DOp, ConvertMaxPool2DGradOp, ConvertMaxPool3DGradOp,
ConvertMeanOp, ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp,
ConvertProdOp, ConvertQrOp, ConvertRangeOp, ConvertSelectV2Op,
ConvertSigmoidOp, ConvertSizeOp, ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
ConvertShapeOp, ConvertSigmoidOp, ConvertSizeOp,
ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSplitOp, ConvertSplitVOp,
ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp,
ConvertTensorScatterUpdateOp, ConvertTileOp, ConvertTopKV2Op,