Fix LowerInvertPermutationOp
<5xi32> to <5x1xi32> is a reshape, not a transpose PiperOrigin-RevId: 306531112 Change-Id: I1e5541bc43997eda222837691bcbad7107f57982
This commit is contained in:
parent
9d53fd3a01
commit
ce55348ee9
@ -3,8 +3,8 @@
|
||||
// CHECK-LABEL: invert_permutation
|
||||
func @invert_permutation(%arg0: tensor<5xi32>) -> tensor<5xi32> {
|
||||
// CHECK-NEXT: %[[UPDATES:.*]] = "tf.Const"() {value = dense<[0, 1, 2, 3, 4]> : tensor<5xi32>} : () -> tensor<5xi32>
|
||||
// CHECK-NEXT: %[[PERM:.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
// CHECK-NEXT: %[[INDICES:.*]] = "tf.Transpose"(%arg0, %[[PERM]]) : (tensor<5xi32>, tensor<2xi32>) -> tensor<5x1xi32>
|
||||
// CHECK-NEXT: %[[SHAPE:.*]] = "tf.Const"() {value = dense<[5, 1]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
// CHECK-NEXT: %[[INDICES:.*]] = "tf.Reshape"(%arg0, %[[SHAPE]]) : (tensor<5xi32>, tensor<2xi32>) -> tensor<5x1xi32>
|
||||
// CHECK-NEXT: "tf.TensorScatterUpdate"(%arg0, %[[INDICES]], %[[UPDATES]]) : (tensor<5xi32>, tensor<5x1xi32>, tensor<5xi32>) -> tensor<5xi32>
|
||||
%0 = "tf.InvertPermutation"(%arg0) : (tensor<5xi32>) -> tensor<5xi32>
|
||||
return %0 : tensor<5xi32>
|
||||
|
@ -253,8 +253,8 @@ class LowerDynamicStitchOp : public OpRewritePattern<TF::DynamicStitchOp> {
|
||||
// %delta = "tf.Const"() {value = dense<1> : tensor<i32>}
|
||||
// %updates = "tf.Range"(%start, %limit, %delta) :
|
||||
// (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<5xi32>
|
||||
// %perm = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>}
|
||||
// %indices = "tf.Transpose"(%x, %perm) : (tensor<5xi32, tensor<2xi32) ->
|
||||
// %shape = "tf.Const"() {value = dense<[5, 1]> : tensor<2xi32>}
|
||||
// %indices = "tf.Reshape"(%x, %shape) : (tensor<5xi32, tensor<2xi32) ->
|
||||
// tensor<5x1xi32>
|
||||
// "tf.TensorScatterUpdate"(%x, %indices, %updates) :
|
||||
// (tensor<5xi32>, tensor<5x1xi32>, tensor<5xi32>) -> tensor<5xi32>
|
||||
@ -268,13 +268,12 @@ class LowerInvertPermutationOp
|
||||
LogicalResult matchAndRewrite(TF::InvertPermutationOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
auto x_type = op.x().getType().cast<TensorType>();
|
||||
Type int_type = x_type.getElementType(); // Could be i32 or i64.
|
||||
|
||||
auto x_type = op.x().getType().dyn_cast<RankedTensorType>();
|
||||
// x input must have static shape.
|
||||
if (!x_type.hasStaticShape()) {
|
||||
if (!x_type || !x_type.hasStaticShape()) {
|
||||
return failure();
|
||||
}
|
||||
Type int_type = x_type.getElementType(); // Could be i32 or i64.
|
||||
|
||||
auto result_type = x_type;
|
||||
auto start =
|
||||
@ -287,13 +286,11 @@ class LowerInvertPermutationOp
|
||||
auto updates =
|
||||
rewriter.create<TF::RangeOp>(loc, result_type, start, limit, delta);
|
||||
|
||||
auto perm_type = RankedTensorType::get({2}, int_type);
|
||||
auto perm = rewriter.create<TF::ConstOp>(
|
||||
loc, DenseElementsAttr::get(perm_type, {1, 0}));
|
||||
auto transposed_x_type =
|
||||
RankedTensorType::get({x_type.getShape()[0], 1}, int_type);
|
||||
auto indices =
|
||||
rewriter.create<TF::TransposeOp>(loc, transposed_x_type, op.x(), perm);
|
||||
auto shape_type = RankedTensorType::get({2}, rewriter.getIntegerType(32));
|
||||
auto shape = rewriter.create<TF::ConstOp>(
|
||||
loc, DenseElementsAttr::get(
|
||||
shape_type, {static_cast<int>(x_type.getDimSize(0)), 1}));
|
||||
auto indices = rewriter.create<TF::ReshapeOp>(loc, op.x(), shape);
|
||||
|
||||
rewriter.replaceOpWithNewOp<TF::TensorScatterUpdateOp>(
|
||||
op, result_type, op.x(), indices, updates);
|
||||
|
Loading…
Reference in New Issue
Block a user