RankOp folder for RankedTensorType
PiperOrigin-RevId: 308362479 Change-Id: Ib4c87c0cfbe31a5cd6dcd7e0bb0b81b050769d7d
This commit is contained in:
parent
947bb83ce2
commit
c942431b49
@ -289,8 +289,8 @@ func @QDQFollowedByRank(%arg0: tensor<1x2xf32>) -> (tensor<i32>) {
|
||||
%2 = "tf.Rank"(%1): (tensor<1x2xf32>) -> tensor<i32>
|
||||
return %2 : tensor<i32>
|
||||
|
||||
// CHECK: %[[R:.*]] = "tf.Rank"(%arg0)
|
||||
// CHECK-NEXT: return %[[R]] : tensor<i32>
|
||||
// CHECK: %[[R:.*]] = constant dense<2>
|
||||
// CHECK: return %cst : tensor<i32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fakeQuantWithConv2D
|
||||
@ -418,14 +418,10 @@ func @matmulNoTransposeAOrB(%arg0: tensor<1x1280xf32>, %arg1: tensor<1280x1000xf
|
||||
return %166 : tensor<1x1000xf32>
|
||||
|
||||
// CHECK-LABEL: matmulNoTransposeAOrB
|
||||
// CHECK: %cst = constant dense<0> : tensor<i32>
|
||||
// CHECK: %cst_0 = constant dense<-1> : tensor<i32>
|
||||
// CHECK: %cst_1 = constant dense<1> : tensor<i32>
|
||||
// CHECK: %0 = "tf.Rank"(%arg1) : (tensor<1280x1000xf32>) -> tensor<i32>
|
||||
// CHECK: %1 = "tf.Range"(%0, %cst, %cst_0) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
|
||||
// CHECK: %2 = "tf.Sub"(%1, %cst_1) : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
|
||||
// CHECK: %3 = "tf.Transpose"(%arg1, %2) : (tensor<1280x1000xf32>, tensor<?xi32>) -> tensor<*xf32>
|
||||
// CHECK: %4 = "tf.MatMul"(%arg0, %3) {transpose_a = false, transpose_b = true} : (tensor<1x1280xf32>, tensor<*xf32>) -> tensor<1x1000xf32>
|
||||
// CHECK: %0 = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<?xi32>
|
||||
// CHECK: %1 = "tf.Transpose"(%arg1, %0) : (tensor<1280x1000xf32>, tensor<?xi32>) -> tensor<*xf32>
|
||||
// CHECK: %2 = "tf.MatMul"(%arg0, %1) {transpose_a = false, transpose_b = true} : (tensor<1x1280xf32>, tensor<*xf32>) -> tensor<1x1000xf32>
|
||||
// CHECK: return %2 : tensor<1x1000xf32>
|
||||
}
|
||||
|
||||
func @matmulNoTransposeB(%arg0: tensor<1x1280xf32>, %arg1: tensor<1280x1000xf32>) -> tensor<1x1000xf32> {
|
||||
@ -433,18 +429,12 @@ func @matmulNoTransposeB(%arg0: tensor<1x1280xf32>, %arg1: tensor<1280x1000xf32>
|
||||
return %166 : tensor<1x1000xf32>
|
||||
|
||||
// CHECK-LABEL: matmulNoTransposeB
|
||||
// CHECK: %cst = constant dense<0> : tensor<i32>
|
||||
// CHECK: %cst_0 = constant dense<-1> : tensor<i32>
|
||||
// CHECK: %cst_1 = constant dense<1> : tensor<i32>
|
||||
// CHECK: %0 = "tf.Rank"(%arg0) : (tensor<1x1280xf32>) -> tensor<i32>
|
||||
// CHECK: %1 = "tf.Range"(%0, %cst, %cst_0) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
|
||||
// CHECK: %2 = "tf.Sub"(%1, %cst_1) : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
|
||||
// CHECK: %3 = "tf.Transpose"(%arg0, %2) : (tensor<1x1280xf32>, tensor<?xi32>) -> tensor<*xf32>
|
||||
// CHECK: %4 = "tf.Rank"(%arg1) : (tensor<1280x1000xf32>) -> tensor<i32>
|
||||
// CHECK: %5 = "tf.Range"(%4, %cst, %cst_0) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
|
||||
// CHECK: %6 = "tf.Sub"(%5, %cst_1) : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
|
||||
// CHECK: %7 = "tf.Transpose"(%arg1, %6) : (tensor<1280x1000xf32>, tensor<?xi32>) -> tensor<*xf32>
|
||||
// CHECK: %8 = "tf.MatMul"(%3, %7) {transpose_a = false, transpose_b = true} : (tensor<*xf32>, tensor<*xf32>) -> tensor<1x1000xf32>
|
||||
// CHECK: %0 = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<?xi32>
|
||||
// CHECK: %1 = "tf.Transpose"(%arg0, %0) : (tensor<1x1280xf32>, tensor<?xi32>) -> tensor<*xf32>
|
||||
// CHECK: %2 = "tf.Transpose"(%arg1, %0) : (tensor<1280x1000xf32>, tensor<?xi32>) -> tensor<*xf32>
|
||||
// CHECK: %3 = "tf.MatMul"(%1, %2) {transpose_a = false, transpose_b = true} : (tensor<*xf32>, tensor<*xf32>) -> tensor<1x1000xf32>
|
||||
// CHECK: return %3 : tensor<1x1000xf32>
|
||||
|
||||
}
|
||||
|
||||
func @snapshot(%arg0: tensor<3xi32>) -> tensor<3xi32> {
|
||||
|
@ -5898,6 +5898,8 @@ of the tensor. Rank is also known as "order", "degree", or "ndims."
|
||||
let builders = [
|
||||
OpBuilder<"Builder* builder, OperationState& result, Value input">
|
||||
];
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TF_ReadVariableOp : TF_Op<"ReadVariableOp", []> {
|
||||
|
@ -2349,6 +2349,17 @@ void RankOp::build(Builder *builder, OperationState &result, Value input) {
|
||||
input);
|
||||
}
|
||||
|
||||
// This will create a constant value for RankOp of a ranked tensor.
|
||||
OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
|
||||
auto type = input().getType();
|
||||
auto ranked_type = type.dyn_cast<RankedTensorType>();
|
||||
if (!ranked_type) return {};
|
||||
|
||||
auto output_type = getType().cast<ShapedType>();
|
||||
int32_t rank = ranked_type.getRank();
|
||||
return DenseIntElementsAttr::get(output_type, rank);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// RealDivOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -462,3 +462,12 @@ func @testMultiReadVariableOpsOfCast(%arg0: tensor<!tf.resource<tensor<f32>>>) -
|
||||
// CHECK: %1 = "tf.ReadVariableOp"(%arg0) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
|
||||
// CHECK: return %1
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testRankOfRankedTensor
|
||||
func @testRankOfRankedTensor(%arg0 : tensor<4x3x2xf32>) -> tensor<i32> {
|
||||
// CHECK:[[VAL0:%.+]] = "tf.Const"() {value = dense<3> : tensor<i32>}
|
||||
%0 = "tf.Rank"(%arg0) : (tensor<4x3x2xf32>) -> tensor<i32>
|
||||
|
||||
// CHECK: return [[VAL0]]
|
||||
return %0 : tensor<i32>
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user