RankOp folder for RankedTensorType

PiperOrigin-RevId: 308362479
Change-Id: Ib4c87c0cfbe31a5cd6dcd7e0bb0b81b050769d7d
This commit is contained in:
Robert Suderman 2020-04-24 18:15:58 -07:00 committed by TensorFlower Gardener
parent 947bb83ce2
commit c942431b49
4 changed files with 34 additions and 22 deletions

View File

@ -289,8 +289,8 @@ func @QDQFollowedByRank(%arg0: tensor<1x2xf32>) -> (tensor<i32>) {
%2 = "tf.Rank"(%1): (tensor<1x2xf32>) -> tensor<i32>
return %2 : tensor<i32>
// CHECK: %[[R:.*]] = "tf.Rank"(%arg0)
// CHECK-NEXT: return %[[R]] : tensor<i32>
// CHECK: %[[R:.*]] = constant dense<2>
// CHECK: return %cst : tensor<i32>
}
// CHECK-LABEL: fakeQuantWithConv2D
@ -418,14 +418,10 @@ func @matmulNoTransposeAOrB(%arg0: tensor<1x1280xf32>, %arg1: tensor<1280x1000xf
return %166 : tensor<1x1000xf32>
// CHECK-LABEL: matmulNoTransposeAOrB
// CHECK: %cst = constant dense<0> : tensor<i32>
// CHECK: %cst_0 = constant dense<-1> : tensor<i32>
// CHECK: %cst_1 = constant dense<1> : tensor<i32>
// CHECK: %0 = "tf.Rank"(%arg1) : (tensor<1280x1000xf32>) -> tensor<i32>
// CHECK: %1 = "tf.Range"(%0, %cst, %cst_0) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
// CHECK: %2 = "tf.Sub"(%1, %cst_1) : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
// CHECK: %3 = "tf.Transpose"(%arg1, %2) : (tensor<1280x1000xf32>, tensor<?xi32>) -> tensor<*xf32>
// CHECK: %4 = "tf.MatMul"(%arg0, %3) {transpose_a = false, transpose_b = true} : (tensor<1x1280xf32>, tensor<*xf32>) -> tensor<1x1000xf32>
// CHECK: %0 = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<?xi32>
// CHECK: %1 = "tf.Transpose"(%arg1, %0) : (tensor<1280x1000xf32>, tensor<?xi32>) -> tensor<*xf32>
// CHECK: %2 = "tf.MatMul"(%arg0, %1) {transpose_a = false, transpose_b = true} : (tensor<1x1280xf32>, tensor<*xf32>) -> tensor<1x1000xf32>
// CHECK: return %2 : tensor<1x1000xf32>
}
func @matmulNoTransposeB(%arg0: tensor<1x1280xf32>, %arg1: tensor<1280x1000xf32>) -> tensor<1x1000xf32> {
@ -433,18 +429,12 @@ func @matmulNoTransposeB(%arg0: tensor<1x1280xf32>, %arg1: tensor<1280x1000xf32>
return %166 : tensor<1x1000xf32>
// CHECK-LABEL: matmulNoTransposeB
// CHECK: %cst = constant dense<0> : tensor<i32>
// CHECK: %cst_0 = constant dense<-1> : tensor<i32>
// CHECK: %cst_1 = constant dense<1> : tensor<i32>
// CHECK: %0 = "tf.Rank"(%arg0) : (tensor<1x1280xf32>) -> tensor<i32>
// CHECK: %1 = "tf.Range"(%0, %cst, %cst_0) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
// CHECK: %2 = "tf.Sub"(%1, %cst_1) : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
// CHECK: %3 = "tf.Transpose"(%arg0, %2) : (tensor<1x1280xf32>, tensor<?xi32>) -> tensor<*xf32>
// CHECK: %4 = "tf.Rank"(%arg1) : (tensor<1280x1000xf32>) -> tensor<i32>
// CHECK: %5 = "tf.Range"(%4, %cst, %cst_0) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
// CHECK: %6 = "tf.Sub"(%5, %cst_1) : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
// CHECK: %7 = "tf.Transpose"(%arg1, %6) : (tensor<1280x1000xf32>, tensor<?xi32>) -> tensor<*xf32>
// CHECK: %8 = "tf.MatMul"(%3, %7) {transpose_a = false, transpose_b = true} : (tensor<*xf32>, tensor<*xf32>) -> tensor<1x1000xf32>
// CHECK: %0 = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<?xi32>
// CHECK: %1 = "tf.Transpose"(%arg0, %0) : (tensor<1x1280xf32>, tensor<?xi32>) -> tensor<*xf32>
// CHECK: %2 = "tf.Transpose"(%arg1, %0) : (tensor<1280x1000xf32>, tensor<?xi32>) -> tensor<*xf32>
// CHECK: %3 = "tf.MatMul"(%1, %2) {transpose_a = false, transpose_b = true} : (tensor<*xf32>, tensor<*xf32>) -> tensor<1x1000xf32>
// CHECK: return %3 : tensor<1x1000xf32>
}
func @snapshot(%arg0: tensor<3xi32>) -> tensor<3xi32> {

View File

@ -5898,6 +5898,8 @@ of the tensor. Rank is also known as "order", "degree", or "ndims."
let builders = [
OpBuilder<"Builder* builder, OperationState& result, Value input">
];
let hasFolder = 1;
}
def TF_ReadVariableOp : TF_Op<"ReadVariableOp", []> {

View File

@ -2349,6 +2349,17 @@ void RankOp::build(Builder *builder, OperationState &result, Value input) {
input);
}
// This will create a constant value for RankOp of a ranked tensor.
OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
auto type = input().getType();
auto ranked_type = type.dyn_cast<RankedTensorType>();
if (!ranked_type) return {};
auto output_type = getType().cast<ShapedType>();
int32_t rank = ranked_type.getRank();
return DenseIntElementsAttr::get(output_type, rank);
}
//===----------------------------------------------------------------------===//
// RealDivOp
//===----------------------------------------------------------------------===//

View File

@ -462,3 +462,12 @@ func @testMultiReadVariableOpsOfCast(%arg0: tensor<!tf.resource<tensor<f32>>>) -
// CHECK: %1 = "tf.ReadVariableOp"(%arg0) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
// CHECK: return %1
}
// CHECK-LABEL: testRankOfRankedTensor
func @testRankOfRankedTensor(%arg0 : tensor<4x3x2xf32>) -> tensor<i32> {
// CHECK:[[VAL0:%.+]] = "tf.Const"() {value = dense<3> : tensor<i32>}
%0 = "tf.Rank"(%arg0) : (tensor<4x3x2xf32>) -> tensor<i32>
// CHECK: return [[VAL0]]
return %0 : tensor<i32>
}