Updated for better compatibility with the Shape dialect

Used a tensor<index>-to-tensor<ixx> cast instead of manually constructing the result shape one dimension at a time.
This also extends support to rankless (unranked) tensors.

PiperOrigin-RevId: 316173553
Change-Id: I10b2b6f94c56eca7b2827b772e9a4532fc7530dd
Robert Suderman 2020-06-12 14:05:39 -07:00 committed by TensorFlower Gardener
parent 694b502173
commit 3a2d113417
2 changed files with 17 additions and 74 deletions
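
The change replaces the old per-dimension expansion (shape.get_extent / shape.size_to_index / tensor_from_elements / xla_hlo.concatenate) with a single cast chain. A minimal sketch of the IR the new pattern produces, inferred from the updated test expectations below; the exact shape-dialect types at this revision are an assumption:

    // Sketch: tf.Shape on an unranked input after legalization (types assumed).
    %shape = shape.shape_of %arg0 : tensor<*xf32>
    %extents = "shape.to_extent_tensor"(%shape) : (!shape.shape) -> tensor<?xindex>
    %result = index_cast %extents : tensor<?xindex> to tensor<?xi32>

A ranked input follows the same pattern; only the extent-tensor type changes (e.g. a tensor<1xindex> cast to tensor<1xi32> for a 1-D shape).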


@@ -1545,61 +1545,33 @@ func @rfft_1D(%arg0: tensor<8xf32>) -> tensor<8xcomplex<f32>> {
 // CHECK-LABEL: func @shape_1D
 func @shape_1D(%arg0: tensor<?xf32>) -> tensor<1xi32> {
   // CHECK: [[SHAPE:%.+]] = shape.shape_of %arg0
-  // CHECK: %[[C0:.*]] = shape.const_size 0
-  // CHECK-DAG: [[EXTENT:%.+]] = shape.get_extent [[SHAPE]], %[[C0]]
-  // CHECK-DAG: [[TO_INDEX:%.+]] = shape.size_to_index [[EXTENT]]
-  // CHECK-DAG: [[CAST:%.+]] = index_cast [[TO_INDEX]]
-  // CHECK-DAG: [[TENSOR:%.+]] = tensor_from_elements([[CAST]])
-  // CHECK-DAG: [[RESHAPE:%.+]] = "xla_hlo.reshape"([[TENSOR]])
-  // CHECK-DAG: [[CONCAT:%.+]] = "xla_hlo.concatenate"([[RESHAPE]]) {dimension = 0 : i64}
+  // CHECK: [[TENSOR:%.+]] = "shape.to_extent_tensor"([[SHAPE]])
+  // CHECK: [[CAST:%.+]] = index_cast [[TENSOR]]
   %0 = "tf.Shape"(%arg0) : (tensor<?xf32>) -> tensor<1xi32>
-  // CHECK: return [[CONCAT]]
+  // CHECK: return [[CAST]]
   return %0 : tensor<1xi32>
 }
 // CHECK-LABEL: func @shape_2D
 func @shape_2D(%arg0: tensor<?x?xf32>) -> tensor<2xi32> {
   // CHECK: [[SHAPE:%.+]] = shape.shape_of %arg0
-  // CHECK: %[[C0:.*]] = shape.const_size 0
-  // CHECK: [[EXTENT0:%.+]] = shape.get_extent [[SHAPE]], %[[C0]]
-  // CHECK-DAG: [[TO_INDEX0:%.+]] = shape.size_to_index [[EXTENT0]]
-  // CHECK-DAG: [[CAST0:%.+]] = index_cast [[TO_INDEX0]]
-  // CHECK-DAG: [[TENSOR0:%.+]] = tensor_from_elements([[CAST0]])
-  // CHECK-DAG: [[RESHAPE0:%.+]] = "xla_hlo.reshape"([[TENSOR0]])
-  // CHECK: %[[C1:.*]] = shape.const_size 1
-  // CHECK-DAG: [[EXTENT1:%.+]] = shape.get_extent [[SHAPE]], %[[C1]]
-  // CHECK-DAG: [[TO_INDEX1:%.+]] = shape.size_to_index [[EXTENT1]]
-  // CHECK-DAG: [[CAST1:%.+]] = index_cast [[TO_INDEX1]]
-  // CHECK-DAG: [[TENSOR1:%.+]] = tensor_from_elements([[CAST1]])
-  // CHECK-DAG: [[RESHAPE1:%.+]] = "xla_hlo.reshape"([[TENSOR1]])
-  // CHECK-DAG: [[CONCAT:%.+]] = "xla_hlo.concatenate"([[RESHAPE0]], [[RESHAPE1]]) {dimension = 0 : i64}
+  // CHECK: [[TENSOR:%.+]] = "shape.to_extent_tensor"([[SHAPE]])
+  // CHECK: [[CAST:%.+]] = index_cast [[TENSOR]]
   %0 = "tf.Shape"(%arg0) : (tensor<?x?xf32>) -> tensor<2xi32>
-  // CHECK: return [[CONCAT]]
+  // CHECK: return [[CAST]]
   return %0 : tensor<2xi32>
 }
-// CHECK-LABEL: func @shape_with_const
-func @shape_with_const(%arg0: tensor<?x3xf32>) -> tensor<2xi32> {
-  // CHECK: [[SHAPE:%.+]] = shape.shape_of %arg0
-  // CHECK: %[[C0:.*]] = shape.const_size 0
-  // CHECK-DAG: [[EXTENT:%.+]] = shape.get_extent [[SHAPE]], %[[C0]]
-  // CHECK-DAG: [[TO_INDEX:%.+]] = shape.size_to_index [[EXTENT]]
-  // CHECK-DAG: [[CAST:%.+]] = index_cast [[TO_INDEX]]
-  // CHECK-DAG: [[TENSOR:%.+]] = tensor_from_elements([[CAST]])
-  // CHECK-DAG: [[RESHAPE:%.+]] = "xla_hlo.reshape"([[TENSOR]])
-  // CHECK-DAG: [[CONST:%.+]] = xla_hlo.constant dense<3>
-  // CHECK-DAG: [[CONCAT:%.+]] = "xla_hlo.concatenate"([[RESHAPE]], [[CONST]]) {dimension = 0 : i64}
-  %0 = "tf.Shape"(%arg0) : (tensor<?x3xf32>) -> tensor<2xi32>
-  // CHECK: return [[CONCAT]]
-  return %0 : tensor<2xi32>
-}
 // CHECK-LABEL: func @shape_rankless
 func @shape_rankless(%arg0: tensor<*xf32>) -> tensor<?xi32> {
+  // CHECK: [[SHAPE:%.+]] = shape.shape_of %arg0
+  // CHECK: [[TENSOR:%.+]] = "shape.to_extent_tensor"([[SHAPE]])
+  // CHECK: [[CAST:%.+]] = index_cast [[TENSOR]]
   %0 = "tf.Shape"(%arg0) : (tensor<*xf32>) -> tensor<?xi32>
+  // CHECK: return [[CAST]]
   return %0 : tensor<?xi32>
 }


@@ -4610,45 +4610,16 @@ class ConvertShapeOp : public OpRewritePattern<TF::ShapeOp> {
   LogicalResult matchAndRewrite(TF::ShapeOp op,
                                 PatternRewriter &rewriter) const override {
     Value input = op.input();
-    auto input_ty = input.getType().dyn_cast<RankedTensorType>();
-    // If the shape is static it can be canonicalized.
-    if (!input_ty || input_ty.hasStaticShape()) {
-      return failure();
-    }
-    auto result_ty = op.getResult().getType().cast<RankedTensorType>();
-    auto element_ty = result_ty.getElementType();
-    int64_t rank = input_ty.getRank();
     auto shape_op = rewriter.create<shape::ShapeOfOp>(op.getLoc(), input);
-    auto index_ty = RankedTensorType::get({1}, element_ty);
-    llvm::SmallVector<Value, 4> dim_values;
-    for (int64_t i = 0; i < rank; ++i) {
-      if (!input_ty.isDynamicDim(i)) {
-        auto dim_attr = DenseElementsAttr::get(
-            index_ty,
-            rewriter.getIntegerAttr(element_ty, input_ty.getDimSize(i)));
-        auto index = rewriter.create<xla_hlo::ConstOp>(op.getLoc(), dim_attr);
-        dim_values.push_back(index);
-        continue;
-      }
-      auto extent_op =
-          rewriter.create<shape::GetExtentOp>(op.getLoc(), shape_op, i);
-      auto index_op = rewriter.create<shape::SizeToIndexOp>(
-          op.getLoc(), rewriter.getIndexType(), extent_op);
-      auto int_op =
-          rewriter.create<IndexCastOp>(op.getLoc(), element_ty, index_op);
-      auto from_tensor = rewriter.create<TensorFromElementsOp>(
-          op.getLoc(), int_op.getResult());
-      auto reshape_op =
-          rewriter.create<ReshapeOp>(op.getLoc(), index_ty, from_tensor);
-      dim_values.push_back(reshape_op);
-    }
-    rewriter.replaceOpWithNewOp<ConcatenateOp>(op, result_ty, dim_values,
-                                               rewriter.getI64IntegerAttr(0));
+    auto result_ty = op.getResult().getType().cast<RankedTensorType>();
+    auto index_tensor =
+        RankedTensorType::get(result_ty.getShape(), rewriter.getIndexType());
+    auto extent_tensor = rewriter.create<shape::ToExtentTensorOp>(
+        op.getLoc(), index_tensor, shape_op);
+    rewriter.replaceOpWithNewOp<IndexCastOp>(op, result_ty, extent_tensor);
     return success();
   }
 };