From c942431b499bd96ee9cd3be6c2910e0710cc5d48 Mon Sep 17 00:00:00 2001
From: Robert Suderman
Date: Fri, 24 Apr 2020 18:15:58 -0700
Subject: [PATCH] RankOp folder for RankedTensorType

PiperOrigin-RevId: 308362479
Change-Id: Ib4c87c0cfbe31a5cd6dcd7e0bb0b81b050769d7d
---
 .../compiler/mlir/lite/tests/prepare-tf.mlir  | 34 +++++++------------
 .../mlir/tensorflow/ir/tf_generated_ops.td    |  2 ++
 .../compiler/mlir/tensorflow/ir/tf_ops.cc     | 11 ++++++
 .../mlir/tensorflow/tests/canonicalize.mlir   |  9 +++++
 4 files changed, 34 insertions(+), 22 deletions(-)

diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
index 5e456b1a7e5..3af0b25a8e3 100644
--- a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
@@ -289,8 +289,8 @@ func @QDQFollowedByRank(%arg0: tensor<1x2xf32>) -> (tensor<i32>) {
   %2 = "tf.Rank"(%1): (tensor<1x2xf32>) -> tensor<i32>
   return %2 : tensor<i32>
 
-// CHECK: %[[R:.*]] = "tf.Rank"(%arg0)
-// CHECK-NEXT: return %[[R]] : tensor<i32>
+// CHECK: %[[R:.*]] = constant dense<2>
+// CHECK: return %cst : tensor<i32>
 }
 
 // CHECK-LABEL: fakeQuantWithConv2D
@@ -418,14 +418,10 @@ func @matmulNoTransposeAOrB(%arg0: tensor<1x1280xf32>, %arg1: tensor<1280x1000xf
   return %166 : tensor<1x1000xf32>
 
 // CHECK-LABEL: matmulNoTransposeAOrB
-  // CHECK: %cst = constant dense<0> : tensor<i32>
-  // CHECK: %cst_0 = constant dense<-1> : tensor<i32>
-  // CHECK: %cst_1 = constant dense<1> : tensor<i32>
-  // CHECK: %0 = "tf.Rank"(%arg1) : (tensor<1280x1000xf32>) -> tensor<i32>
-  // CHECK: %1 = "tf.Range"(%0, %cst, %cst_0) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
-  // CHECK: %2 = "tf.Sub"(%1, %cst_1) : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
-  // CHECK: %3 = "tf.Transpose"(%arg1, %2) : (tensor<1280x1000xf32>, tensor<?xi32>) -> tensor<*xf32>
-  // CHECK: %4 = "tf.MatMul"(%arg0, %3) {transpose_a = false, transpose_b = true} : (tensor<1x1280xf32>, tensor<*xf32>) -> tensor<1x1000xf32>
+  // CHECK: %0 = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<?xi32>
+  // CHECK: %1 = "tf.Transpose"(%arg1, %0) : (tensor<1280x1000xf32>, tensor<?xi32>) -> tensor<*xf32>
+  // CHECK: %2 = "tf.MatMul"(%arg0, %1) {transpose_a = false, transpose_b = true} : (tensor<1x1280xf32>, tensor<*xf32>) -> tensor<1x1000xf32>
+  // CHECK: return %2 : tensor<1x1000xf32>
 }
 
 func @matmulNoTransposeB(%arg0: tensor<1x1280xf32>, %arg1: tensor<1280x1000xf32>) -> tensor<1x1000xf32> {
@@ -433,18 +429,12 @@ func @matmulNoTransposeB(%arg0: tensor<1x1280xf32>, %arg1: tensor<1280x1000xf32>
   return %166 : tensor<1x1000xf32>
 
 // CHECK-LABEL: matmulNoTransposeB
-  // CHECK: %cst = constant dense<0> : tensor<i32>
-  // CHECK: %cst_0 = constant dense<-1> : tensor<i32>
-  // CHECK: %cst_1 = constant dense<1> : tensor<i32>
-  // CHECK: %0 = "tf.Rank"(%arg0) : (tensor<1x1280xf32>) -> tensor<i32>
-  // CHECK: %1 = "tf.Range"(%0, %cst, %cst_0) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
-  // CHECK: %2 = "tf.Sub"(%1, %cst_1) : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
-  // CHECK: %3 = "tf.Transpose"(%arg0, %2) : (tensor<1x1280xf32>, tensor<?xi32>) -> tensor<*xf32>
-  // CHECK: %4 = "tf.Rank"(%arg1) : (tensor<1280x1000xf32>) -> tensor<i32>
-  // CHECK: %5 = "tf.Range"(%4, %cst, %cst_0) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
-  // CHECK: %6 = "tf.Sub"(%5, %cst_1) : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
-  // CHECK: %7 = "tf.Transpose"(%arg1, %6) : (tensor<1280x1000xf32>, tensor<?xi32>) -> tensor<*xf32>
-  // CHECK: %8 = "tf.MatMul"(%3, %7) {transpose_a = false, transpose_b = true} : (tensor<*xf32>, tensor<*xf32>) -> tensor<1x1000xf32>
+  // CHECK: %0 = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<?xi32>
+  // CHECK: %1 = "tf.Transpose"(%arg0, %0) : (tensor<1x1280xf32>, tensor<?xi32>) -> tensor<*xf32>
+  // CHECK: %2 = "tf.Transpose"(%arg1, %0) : (tensor<1280x1000xf32>, tensor<?xi32>) -> tensor<*xf32>
+  // CHECK: %3 = "tf.MatMul"(%1, %2) {transpose_a = false, transpose_b = true} : (tensor<*xf32>, tensor<*xf32>) -> tensor<1x1000xf32>
+  // CHECK: return %3 : tensor<1x1000xf32>
+
 }
 
 func @snapshot(%arg0: tensor<3xi32>) -> tensor<3xi32> {
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index d14be1cfd0f..59b6b206513 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -5898,6 +5898,8 @@ of the tensor. Rank is also known as "order", "degree", or "ndims."
   let builders = [
     OpBuilder<"Builder* builder, OperationState& result, Value input">
   ];
+
+  let hasFolder = 1;
 }
 
 def TF_ReadVariableOp : TF_Op<"ReadVariableOp", []> {
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
index ba06e311524..f61b5edc45d 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
@@ -2349,6 +2349,17 @@ void RankOp::build(Builder *builder, OperationState &result, Value input) {
         input);
 }
 
+// This will create a constant value for RankOp of a ranked tensor.
+OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
+  auto type = input().getType();
+  auto ranked_type = type.dyn_cast<RankedTensorType>();
+  if (!ranked_type) return {};
+
+  auto output_type = getType().cast<ShapedType>();
+  int32_t rank = ranked_type.getRank();
+  return DenseIntElementsAttr::get(output_type, rank);
+}
+
 //===----------------------------------------------------------------------===//
 // RealDivOp
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
index 7f362a19e04..18f8d5f4486 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
@@ -462,3 +462,12 @@ func @testMultiReadVariableOpsOfCast(%arg0: tensor<*x!tf.resource<tensor<f32>>>) -
 // CHECK: %1 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
 // CHECK: return %1
 }
+
+// CHECK-LABEL: testRankOfRankedTensor
+func @testRankOfRankedTensor(%arg0 : tensor<4x3x2xf32>) -> tensor<i32> {
+  // CHECK:[[VAL0:%.+]] = "tf.Const"() {value = dense<3> : tensor<i32>}
+  %0 = "tf.Rank"(%arg0) : (tensor<4x3x2xf32>) -> tensor<i32>
+
+  // CHECK: return [[VAL0]]
+  return %0 : tensor<i32>
+}