RankOp folder for RankedTensorType
PiperOrigin-RevId: 308362479 Change-Id: Ib4c87c0cfbe31a5cd6dcd7e0bb0b81b050769d7d
This commit is contained in:
parent
947bb83ce2
commit
c942431b49
@ -289,8 +289,8 @@ func @QDQFollowedByRank(%arg0: tensor<1x2xf32>) -> (tensor<i32>) {
|
|||||||
%2 = "tf.Rank"(%1): (tensor<1x2xf32>) -> tensor<i32>
|
%2 = "tf.Rank"(%1): (tensor<1x2xf32>) -> tensor<i32>
|
||||||
return %2 : tensor<i32>
|
return %2 : tensor<i32>
|
||||||
|
|
||||||
// CHECK: %[[R:.*]] = "tf.Rank"(%arg0)
|
// CHECK: %[[R:.*]] = constant dense<2>
|
||||||
// CHECK-NEXT: return %[[R]] : tensor<i32>
|
// CHECK: return %cst : tensor<i32>
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: fakeQuantWithConv2D
|
// CHECK-LABEL: fakeQuantWithConv2D
|
||||||
@ -418,14 +418,10 @@ func @matmulNoTransposeAOrB(%arg0: tensor<1x1280xf32>, %arg1: tensor<1280x1000xf
|
|||||||
return %166 : tensor<1x1000xf32>
|
return %166 : tensor<1x1000xf32>
|
||||||
|
|
||||||
// CHECK-LABEL: matmulNoTransposeAOrB
|
// CHECK-LABEL: matmulNoTransposeAOrB
|
||||||
// CHECK: %cst = constant dense<0> : tensor<i32>
|
// CHECK: %0 = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<?xi32>
|
||||||
// CHECK: %cst_0 = constant dense<-1> : tensor<i32>
|
// CHECK: %1 = "tf.Transpose"(%arg1, %0) : (tensor<1280x1000xf32>, tensor<?xi32>) -> tensor<*xf32>
|
||||||
// CHECK: %cst_1 = constant dense<1> : tensor<i32>
|
// CHECK: %2 = "tf.MatMul"(%arg0, %1) {transpose_a = false, transpose_b = true} : (tensor<1x1280xf32>, tensor<*xf32>) -> tensor<1x1000xf32>
|
||||||
// CHECK: %0 = "tf.Rank"(%arg1) : (tensor<1280x1000xf32>) -> tensor<i32>
|
// CHECK: return %2 : tensor<1x1000xf32>
|
||||||
// CHECK: %1 = "tf.Range"(%0, %cst, %cst_0) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
|
|
||||||
// CHECK: %2 = "tf.Sub"(%1, %cst_1) : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
|
|
||||||
// CHECK: %3 = "tf.Transpose"(%arg1, %2) : (tensor<1280x1000xf32>, tensor<?xi32>) -> tensor<*xf32>
|
|
||||||
// CHECK: %4 = "tf.MatMul"(%arg0, %3) {transpose_a = false, transpose_b = true} : (tensor<1x1280xf32>, tensor<*xf32>) -> tensor<1x1000xf32>
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func @matmulNoTransposeB(%arg0: tensor<1x1280xf32>, %arg1: tensor<1280x1000xf32>) -> tensor<1x1000xf32> {
|
func @matmulNoTransposeB(%arg0: tensor<1x1280xf32>, %arg1: tensor<1280x1000xf32>) -> tensor<1x1000xf32> {
|
||||||
@ -433,18 +429,12 @@ func @matmulNoTransposeB(%arg0: tensor<1x1280xf32>, %arg1: tensor<1280x1000xf32>
|
|||||||
return %166 : tensor<1x1000xf32>
|
return %166 : tensor<1x1000xf32>
|
||||||
|
|
||||||
// CHECK-LABEL: matmulNoTransposeB
|
// CHECK-LABEL: matmulNoTransposeB
|
||||||
// CHECK: %cst = constant dense<0> : tensor<i32>
|
// CHECK: %0 = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<?xi32>
|
||||||
// CHECK: %cst_0 = constant dense<-1> : tensor<i32>
|
// CHECK: %1 = "tf.Transpose"(%arg0, %0) : (tensor<1x1280xf32>, tensor<?xi32>) -> tensor<*xf32>
|
||||||
// CHECK: %cst_1 = constant dense<1> : tensor<i32>
|
// CHECK: %2 = "tf.Transpose"(%arg1, %0) : (tensor<1280x1000xf32>, tensor<?xi32>) -> tensor<*xf32>
|
||||||
// CHECK: %0 = "tf.Rank"(%arg0) : (tensor<1x1280xf32>) -> tensor<i32>
|
// CHECK: %3 = "tf.MatMul"(%1, %2) {transpose_a = false, transpose_b = true} : (tensor<*xf32>, tensor<*xf32>) -> tensor<1x1000xf32>
|
||||||
// CHECK: %1 = "tf.Range"(%0, %cst, %cst_0) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
|
// CHECK: return %3 : tensor<1x1000xf32>
|
||||||
// CHECK: %2 = "tf.Sub"(%1, %cst_1) : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
|
|
||||||
// CHECK: %3 = "tf.Transpose"(%arg0, %2) : (tensor<1x1280xf32>, tensor<?xi32>) -> tensor<*xf32>
|
|
||||||
// CHECK: %4 = "tf.Rank"(%arg1) : (tensor<1280x1000xf32>) -> tensor<i32>
|
|
||||||
// CHECK: %5 = "tf.Range"(%4, %cst, %cst_0) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
|
|
||||||
// CHECK: %6 = "tf.Sub"(%5, %cst_1) : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
|
|
||||||
// CHECK: %7 = "tf.Transpose"(%arg1, %6) : (tensor<1280x1000xf32>, tensor<?xi32>) -> tensor<*xf32>
|
|
||||||
// CHECK: %8 = "tf.MatMul"(%3, %7) {transpose_a = false, transpose_b = true} : (tensor<*xf32>, tensor<*xf32>) -> tensor<1x1000xf32>
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func @snapshot(%arg0: tensor<3xi32>) -> tensor<3xi32> {
|
func @snapshot(%arg0: tensor<3xi32>) -> tensor<3xi32> {
|
||||||
|
@ -5898,6 +5898,8 @@ of the tensor. Rank is also known as "order", "degree", or "ndims."
|
|||||||
let builders = [
|
let builders = [
|
||||||
OpBuilder<"Builder* builder, OperationState& result, Value input">
|
OpBuilder<"Builder* builder, OperationState& result, Value input">
|
||||||
];
|
];
|
||||||
|
|
||||||
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def TF_ReadVariableOp : TF_Op<"ReadVariableOp", []> {
|
def TF_ReadVariableOp : TF_Op<"ReadVariableOp", []> {
|
||||||
|
@ -2349,6 +2349,17 @@ void RankOp::build(Builder *builder, OperationState &result, Value input) {
|
|||||||
input);
|
input);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This will create a constant value for RankOp of a ranked tensor.
|
||||||
|
OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
|
||||||
|
auto type = input().getType();
|
||||||
|
auto ranked_type = type.dyn_cast<RankedTensorType>();
|
||||||
|
if (!ranked_type) return {};
|
||||||
|
|
||||||
|
auto output_type = getType().cast<ShapedType>();
|
||||||
|
int32_t rank = ranked_type.getRank();
|
||||||
|
return DenseIntElementsAttr::get(output_type, rank);
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// RealDivOp
|
// RealDivOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -462,3 +462,12 @@ func @testMultiReadVariableOpsOfCast(%arg0: tensor<!tf.resource<tensor<f32>>>) -
|
|||||||
// CHECK: %1 = "tf.ReadVariableOp"(%arg0) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
|
// CHECK: %1 = "tf.ReadVariableOp"(%arg0) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
|
||||||
// CHECK: return %1
|
// CHECK: return %1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: testRankOfRankedTensor
|
||||||
|
func @testRankOfRankedTensor(%arg0 : tensor<4x3x2xf32>) -> tensor<i32> {
|
||||||
|
// CHECK:[[VAL0:%.+]] = "tf.Const"() {value = dense<3> : tensor<i32>}
|
||||||
|
%0 = "tf.Rank"(%arg0) : (tensor<4x3x2xf32>) -> tensor<i32>
|
||||||
|
|
||||||
|
// CHECK: return [[VAL0]]
|
||||||
|
return %0 : tensor<i32>
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user