Rewrite graph with TensorList ops to support the dynamic RNN use case.

PiperOrigin-RevId: 256272543

commit 05a4e4bf8c
parent 9e0e23c59f
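At a high level, the commit switches the lowering pass to run over the whole module so it can rewrite `tf.While` loops that carry a tensor list, which is the graph shape a dynamic RNN produces. A condensed sketch of that input pattern, abridged from the new test case added below (value names are illustrative and the `T` attribute is trimmed for brevity):

func @tensorlistWhileLoop(%arg0: tensor<2x3xf32>) -> tensor<*xf32> {
  %element_shape = constant dense<3> : tensor<1xi32>
  %start = constant dense<0> : tensor<i32>
  %unknown_dim = constant dense<-1> : tensor<i32>
  // The list handle is a DT_VARIANT value threaded through the loop.
  %list = "tf.TensorListFromTensor"(%arg0, %element_shape)
      : (tensor<2x3xf32>, tensor<1xi32>) -> tensor<!tf.variant>
  %loop:2 = "tf.While"(%start, %list)
      {body = @tensorlistWhileBody, cond = @tensorlistWhileCond}
      : (tensor<i32>, tensor<!tf.variant>) -> (tensor<i32>, tensor<!tf.variant>)
  %stacked = "tf.TensorListStack"(%loop#1, %unknown_dim)
      : (tensor<!tf.variant>, tensor<i32>) -> tensor<*xf32>
  return %stacked : tensor<*xf32>
}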
@@ -232,6 +232,7 @@ cc_library(
         "@local_config_mlir//:QuantOps",
         "@local_config_mlir//:StandardOps",
         "@local_config_mlir//:Support",
+        "@local_config_mlir//:TypeUtilities",
     ],
     alwayslink = 1,
 )
@@ -1,5 +1,4 @@
-// RUN: tf-opt -tfl-lower-static-tensor-list %s | FileCheck %s
-
+// RUN: tf-opt -tfl-lower-static-tensor-list %s | FileCheck %s --dump-input-on-failure
 func @tensorlistGetItem(tensor<3x10xf32>, tensor<1xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<3x10xf32>) {
 ^bb0(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor<i32>):
   %0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<3x10xf32>, tensor<1xi32>) -> tensor<*x!tf.variant>
@@ -32,25 +31,28 @@ func @tensorlistSetItem(tensor<3x10xf32>, tensor<1xi32>, tensor<i32>, tensor<10x
   return %2 : tensor<3x10xf32>
 
 // CHECK-LABEL: tensorlistSetItem
-// CHECK: %cst = constant dense<1> : tensor<1xi32>
-// CHECK: %cst_0 = constant dense<0> : tensor<i32>
-// CHECK: %cst_1 = constant dense<-1> : tensor<i32>
 // CHECK: %0 = "tf.Rank"(%arg0) : (tensor<3x10xf32>) -> tensor<i32>
 // CHECK: %1 = "tf.Rank"(%arg3) : (tensor<10xf32>) -> tensor<i32>
-// CHECK: %2 = "tf.ExpandDims"(%0, %cst_0) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
+// CHECK: %cst = constant dense<0> : tensor<i32>
+// CHECK: %2 = "tf.ExpandDims"(%0, %cst) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
+// CHECK: %cst_0 = constant dense<0> : tensor<i32>
 // CHECK: %3 = "tf.Fill"(%2, %cst_0) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
-// CHECK: %4 = "tf.Add"(%arg2, %cst) : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
-// CHECK: %5 = "tf.ExpandDims"(%1, %cst_0) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
-// CHECK: %6 = "tf.Fill"(%5, %cst_0) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
-// CHECK: %7 = "tf.Concat"(%cst_0, %4, %6) {N = 2 : i64} : (tensor<i32>, tensor<1xi32>, tensor<?xi32>) -> tensor<?xi32>
-// CHECK: %8 = "tf.ExpandDims"(%arg2, %cst_0) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
-// CHECK: %9 = "tf.Fill"(%5, %cst_1) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
-// CHECK: %10 = "tf.Concat"(%cst_0, %8, %9) {N = 2 : i64} : (tensor<i32>, tensor<1xi32>, tensor<?xi32>) -> tensor<?xi32>
-// CHECK: %11 = "tf.Fill"(%2, %cst_1) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
+// CHECK: %cst_1 = constant dense<1> : tensor<1xi32>
+// CHECK: %4 = "tf.Add"(%arg2, %cst_1) : (tensor<i32>, tensor<1xi32>) -> tensor<*xi32>
+// CHECK: %5 = "tf.ExpandDims"(%1, %cst) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
+// CHECK: %cst_2 = constant dense<0> : tensor<i32>
+// CHECK: %6 = "tf.Fill"(%5, %cst_2) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
+// CHECK: %7 = "tf.Concat"(%cst, %4, %6) {N = 2 : i64} : (tensor<i32>, tensor<*xi32>, tensor<?xi32>) -> tensor<?xi32>
+// CHECK: %8 = "tf.ExpandDims"(%arg2, %cst) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
+// CHECK: %cst_3 = constant dense<-1> : tensor<i32>
+// CHECK: %9 = "tf.Fill"(%5, %cst_3) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
+// CHECK: %10 = "tf.Concat"(%cst, %8, %9) {N = 2 : i64} : (tensor<i32>, tensor<1xi32>, tensor<?xi32>) -> tensor<?xi32>
+// CHECK: %cst_4 = constant dense<-1> : tensor<i32>
+// CHECK: %11 = "tf.Fill"(%2, %cst_4) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
 // CHECK: %12 = "tf.Slice"(%arg0, %3, %10) : (tensor<3x10xf32>, tensor<?xi32>, tensor<?xi32>) -> tensor<*xf32>
 // CHECK: %13 = "tf.Slice"(%arg0, %7, %11) : (tensor<3x10xf32>, tensor<?xi32>, tensor<?xi32>) -> tensor<*xf32>
-// CHECK: %14 = "tf.ExpandDims"(%arg3, %cst_0) : (tensor<10xf32>, tensor<i32>) -> tensor<*xf32>
-// CHECK: %15 = "tf.Concat"(%cst_0, %12, %14, %13) {N = 3 : i64} : (tensor<i32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<3x10xf32>
+// CHECK: %14 = "tf.ExpandDims"(%arg3, %cst) : (tensor<10xf32>, tensor<i32>) -> tensor<*xf32>
+// CHECK: %15 = "tf.Concat"(%cst, %12, %14, %13) {N = 3 : i64} : (tensor<i32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<3x10xf32>
 // CHECK: return %15 : tensor<3x10xf32>
 }
 
@@ -62,25 +64,28 @@ func @tensorlistSetItemWithScalarElements(tensor<5xf32>, tensor<0xi32>, tensor<i
   return %2 : tensor<5xf32>
 
 // CHECK-LABEL: tensorlistSetItemWithScalarElements
-// CHECK: %cst = constant dense<1> : tensor<1xi32>
-// CHECK: %cst_0 = constant dense<0> : tensor<i32>
-// CHECK: %cst_1 = constant dense<-1> : tensor<i32>
 // CHECK: %0 = "tf.Rank"(%arg0) : (tensor<5xf32>) -> tensor<i32>
 // CHECK: %1 = "tf.Rank"(%arg3) : (tensor<f32>) -> tensor<i32>
-// CHECK: %2 = "tf.ExpandDims"(%0, %cst_0) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
+// CHECK: %cst = constant dense<0> : tensor<i32>
+// CHECK: %2 = "tf.ExpandDims"(%0, %cst) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
+// CHECK: %cst_0 = constant dense<0> : tensor<i32>
 // CHECK: %3 = "tf.Fill"(%2, %cst_0) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
-// CHECK: %4 = "tf.Add"(%arg2, %cst) : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
-// CHECK: %5 = "tf.ExpandDims"(%1, %cst_0) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
-// CHECK: %6 = "tf.Fill"(%5, %cst_0) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
-// CHECK: %7 = "tf.Concat"(%cst_0, %4, %6) {N = 2 : i64} : (tensor<i32>, tensor<1xi32>, tensor<?xi32>) -> tensor<?xi32>
-// CHECK: %8 = "tf.ExpandDims"(%arg2, %cst_0) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
-// CHECK: %9 = "tf.Fill"(%5, %cst_1) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
-// CHECK: %10 = "tf.Concat"(%cst_0, %8, %9) {N = 2 : i64} : (tensor<i32>, tensor<1xi32>, tensor<?xi32>) -> tensor<?xi32>
-// CHECK: %11 = "tf.Fill"(%2, %cst_1) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
+// CHECK: %cst_1 = constant dense<1> : tensor<1xi32>
+// CHECK: %4 = "tf.Add"(%arg2, %cst_1) : (tensor<i32>, tensor<1xi32>) -> tensor<*xi32>
+// CHECK: %5 = "tf.ExpandDims"(%1, %cst) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
+// CHECK: %cst_2 = constant dense<0> : tensor<i32>
+// CHECK: %6 = "tf.Fill"(%5, %cst_2) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
+// CHECK: %7 = "tf.Concat"(%cst, %4, %6) {N = 2 : i64} : (tensor<i32>, tensor<*xi32>, tensor<?xi32>) -> tensor<?xi32>
+// CHECK: %8 = "tf.ExpandDims"(%arg2, %cst) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
+// CHECK: %cst_3 = constant dense<-1> : tensor<i32>
+// CHECK: %9 = "tf.Fill"(%5, %cst_3) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
+// CHECK: %10 = "tf.Concat"(%cst, %8, %9) {N = 2 : i64} : (tensor<i32>, tensor<1xi32>, tensor<?xi32>) -> tensor<?xi32>
+// CHECK: %cst_4 = constant dense<-1> : tensor<i32>
+// CHECK: %11 = "tf.Fill"(%2, %cst_4) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
 // CHECK: %12 = "tf.Slice"(%arg0, %3, %10) : (tensor<5xf32>, tensor<?xi32>, tensor<?xi32>) -> tensor<*xf32>
 // CHECK: %13 = "tf.Slice"(%arg0, %7, %11) : (tensor<5xf32>, tensor<?xi32>, tensor<?xi32>) -> tensor<*xf32>
-// CHECK: %14 = "tf.ExpandDims"(%arg3, %cst_0) : (tensor<f32>, tensor<i32>) -> tensor<*xf32>
-// CHECK: %15 = "tf.Concat"(%cst_0, %12, %14, %13) {N = 3 : i64} : (tensor<i32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<5xf32>
+// CHECK: %14 = "tf.ExpandDims"(%arg3, %cst) : (tensor<f32>, tensor<i32>) -> tensor<*xf32>
+// CHECK: %15 = "tf.Concat"(%cst, %12, %14, %13) {N = 3 : i64} : (tensor<i32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<5xf32>
 // CHECK: return %15 : tensor<5xf32>
 }
 
@@ -92,10 +97,56 @@ func @tensorlistReserve(tensor<3xi32>, tensor<i32>, tensor<i32>) -> tensor<3xf32
 
 // CHECK-LABEL: tensorlistReserve
 // CHECK: %cst = constant dense<0> : tensor<i32>
-// CHECK: %cst_0 = constant dense<0.000000e+00> : tensor<f32>
 // CHECK: %0 = "tf.ExpandDims"(%arg1, %cst) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
 // CHECK: %1 = "tf.Concat"(%cst, %0, %arg0) {N = 2 : i64} : (tensor<i32>, tensor<1xi32>, tensor<3xi32>) -> tensor<4xi32>
+// CHECK: %cst_0 = constant dense<0.000000e+00> : tensor<f32>
 // CHECK: %2 = "tf.Fill"(%1, %cst_0) : (tensor<4xi32>, tensor<f32>) -> tensor<*xf32>
 // CHECK: %3 = "tf.Gather"(%2, %arg2) {validate_indices = true} : (tensor<*xf32>, tensor<i32>) -> tensor<3xf32>
 // CHECK: return %3 : tensor<3xf32>
 }
+
+func @tensorlistWhileLoop(tensor<2x3xf32>) -> tensor<*xf32> {
+^bb0(%arg0: tensor<2x3xf32>):
+  %cst = constant dense<3> : tensor<1xi32>
+  %cst_0 = constant dense<0> : tensor<i32>
+  %cst_1 = constant dense<-1> : tensor<i32>
+  %0 = "tf.TensorListFromTensor"(%arg0, %cst) : (tensor<2x3xf32>, tensor<1xi32>) -> tensor<!tf.variant>
+  %1:2 = "tf.While"(%cst_0, %0) {T = ["tfdtype$DT_INT32", "tfdtype$DT_VARIANT"], body = @tensorlistWhileBody, cond = @tensorlistWhileCond} : (tensor<i32>, tensor<!tf.variant>) -> (tensor<i32>, tensor<!tf.variant>)
+  %2 = "tf.TensorListStack"(%1#1, %cst_1) : (tensor<!tf.variant>, tensor<i32>) -> tensor<*xf32>
+  return %2 : tensor<*xf32>
+
+// make sure the variant types in input/output have been updated, and `T` attribute
+// is removed.
+// CHECK-LABEL: func @tensorlistWhileLoop
+// CHECK-NOT: "tf.While"{{.*}}T =
+// CHECK: "tf.While"
+// CHECK-SAME: (tensor<i32>, tensor<2x3xf32>) -> (tensor<i32>, tensor<*xf32>)
+// CHECK: return %0#1 : tensor<*xf32>
+}
+
+func @tensorlistWhileBody(tensor<*xi32>, tensor<*x!tf.variant>) -> (tensor<*xi32>, tensor<*x!tf.variant>) {
+^bb0(%arg0: tensor<*xi32>, %arg1: tensor<*x!tf.variant>):
+  %cst = constant dense<1> : tensor<i32>
+  %0 = "tf.Add"(%arg0, %cst) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
+  %1 = "tf.Identity"(%arg1) : (tensor<*x!tf.variant>) -> tensor<*x!tf.variant>
+  return %0, %1 : tensor<*xi32>, tensor<*x!tf.variant>
+
+// verify `body` function's signature.
+// CHECK: func @tensorlistWhileBody(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> (tensor<*xi32>, tensor<*xf32>)
+// CHECK: %0 = "tf.Add"(%arg0, %cst) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
+// CHECK-NOT: tensor<*x!tf.variant>
+// CHECK: %1 = "tf.Identity"(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
+// CHECK: return %0, %1 : tensor<*xi32>, tensor<*xf32>
+}
+
+func @tensorlistWhileCond(tensor<*xi32>, tensor<*x!tf.variant>) -> tensor<*xi1> {
+^bb0(%arg0: tensor<*xi32>, %arg1: tensor<*x!tf.variant>):
+  %cst = constant dense<2> : tensor<i32>
+  %0 = "tf.Less"(%arg0, %cst) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi1>
+  return %0 : tensor<*xi1>
+
+// verify `cond` function's signature.
+// CHECK: func @tensorlistWhileCond(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> tensor<*xi1>
+// CHECK: %0 = "tf.Less"(%arg0, %cst) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi1>
+// CHECK: return %0 : tensor<*xi1>
+}
@@ -32,6 +32,7 @@ limitations under the License.
 #include "mlir/IR/Block.h"  // TF:local_config_mlir
 #include "mlir/IR/MLIRContext.h"  // TF:local_config_mlir
 #include "mlir/IR/Operation.h"  // TF:local_config_mlir
+#include "mlir/IR/OperationSupport.h"  // TF:local_config_mlir
 #include "mlir/IR/PatternMatch.h"  // TF:local_config_mlir
 #include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
 #include "mlir/IR/Types.h"  // TF:local_config_mlir
@@ -42,11 +43,13 @@ limitations under the License.
 #include "mlir/Support/Functional.h"  // TF:local_config_mlir
 #include "mlir/Support/LLVM.h"  // TF:local_config_mlir
 #include "mlir/Support/LogicalResult.h"  // TF:local_config_mlir
+#include "mlir/Support/TypeUtilities.h"  // TF:local_config_mlir
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
 #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
 #include "tensorflow/compiler/mlir/lite/utils/validators.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
 
 #define DEBUG_TYPE "tf-tfl-legalization"
 
@@ -56,11 +59,30 @@ limitations under the License.
 namespace mlir {
 namespace {
 
-// Lower TensorList ops in functions for subsequent legalization.
+class TensorListPatternRewriter : public PatternRewriter {
+ public:
+  explicit TensorListPatternRewriter(Function fn)
+      : PatternRewriter(fn.getBody()) {}
+
+  Operation *createOperation(const OperationState &state) override {
+    return OpBuilder::createOperation(state);
+  }
+};
+
+/// Lower TensorList ops in functions for subsequent legalization.
+// TODO(haoliang): Use DialectConversion infra to simplify the rewriting
+// process.
 struct LowerStaticTensorListPass
-    : public FunctionPass<LowerStaticTensorListPass> {
-  void runOnFunction() override;
-  LogicalResult ModifyTensorList();
+    : public ModulePass<LowerStaticTensorListPass> {
+  void runOnModule() override;
+
+  // Apply type and op changes within a function.
+  LogicalResult RewriteFunction(Function func,
+                                TensorListPatternRewriter *rewriter);
+
+  // Changes the function type of `cond_func` and `body_func`, and the result
+  // type of the `WhileOp`.
+  LogicalResult UpdateWhileFunctionType(TF::WhileOp *while_op);
 };
 
 Value *CreateI32SplatConst(Operation *op, PatternRewriter *rewriter,
@@ -121,9 +143,9 @@ struct ConvertTFTensorListSetItem : public RewritePattern {
     // Calculate the first dimension, which is index + 1.
     auto index = tf_op.index();
     auto vector_type = rewriter.getTensorType({1}, shape_dtype);
-    auto begin =
-        rewriter.create<TF::AddOp>(op->getLoc(), vector_type, index,
+    auto begin = rewriter.create<TF::AddOp>(
+        op->getLoc(), rewriter.getTensorType(shape_dtype), index,
         CreateI32SplatConst(op, &rewriter, {1}, 1));
 
     // Followed by the first dimension `begin`, are `item_rank` of 0s.
     auto item_position_shape = rewriter.create<TF::ExpandDimsOp>(
@@ -198,7 +220,7 @@ struct ConvertTFTensorListReserve : public RewritePattern {
     if (auto type = element_shape->getType().dyn_cast<RankedTensorType>()) {
       // Note that the first item of the shape array is the element's rank, add
       // it by 1 to get the input's rank.
-      if (type.hasStaticShape()) {
+      if (type.hasStaticShape() && type.getRank() != 0) {
         input_rank = type.getShape()[0] + 1;
       }
     }
@@ -236,98 +258,149 @@ namespace {
 } // namespace
 } // namespace TFL
 
-LogicalResult LowerStaticTensorListPass::ModifyTensorList() {
-  // In `runOnFunction`, there is no guarantee about
-  // in which order those patterns will be applied. Our transformation requires
-  // that at runtime each `TensorListSetItem` op takes in a normal tensor type
-  // rather than a `DT_VARIANT` tensor. So here we need to manually walk-through
-  // the IR and change the argument/return types of each `TensorListSetItemOp`.
-  // TODO(haoliang): 1) support modifying more `TensorList` ops that consumes/
-  // produces `DT_VARIANT` tensor. 2) More robust support for handling multiple
-  // different tensorlist types. For example, consider the case like:
-  // l1 = list_ops.tensor_list_from_tensor(t, element_shape1)
-  // l2 = list_ops.tensor_list_from_tensor(t, element_shape2)
-  // l1 = list_ops.tensor_list_set_item(l1, 0, item1)
-  // l2 = list_ops.tensor_list_set_item(l2, 0, item2)
-  // 3) Handle the case where a tensorlist output is passed to multiple
-  // functions.
-  for (Block &block : getFunction()) {
-    Type tensor_type;
+LogicalResult LowerStaticTensorListPass::UpdateWhileFunctionType(
+    TF::WhileOp *while_op) {
+  SmallVector<Type, 8> unranked_argument_types;
+  for (const auto &operand : while_op->getOperands()) {
+    unranked_argument_types.push_back(
+        UnrankedTensorType::get(getElementTypeOrSelf(operand->getType())));
+  }
+
+  auto *context = &getContext();
+  auto module = getModule();
+  Function cond_func = module.getNamedFunction(while_op->getCond());
+  Function body_func = module.getNamedFunction(while_op->getBody());
+
+  if (cond_func) {
+    // Change `cond_func`'s argument types to `unranked_argument_types`.
+    cond_func.setType(FunctionType::get(
+        unranked_argument_types, cond_func.getType().getResults(), context));
+    // Change the argument type for the first block.
+    Block &cond_first_bb = cond_func.front();
+    for (int i = 0; i < cond_first_bb.getNumArguments(); ++i) {
+      cond_first_bb.getArgument(i)->setType(unranked_argument_types[i]);
+    }
+  }
+
+  if (body_func) {
+    SmallVector<Type, 8> updated_result_types;
+    for (int i = 0; i < body_func.getType().getNumResults(); ++i) {
+      auto result_type = body_func.getType().getResult(i);
+      if (getElementTypeOrSelf(result_type).isa<TF::VariantType>()) {
+        // For variant type, use the corresponding unranked type.
+        updated_result_types.push_back(unranked_argument_types[i]);
+      } else {
+        updated_result_types.push_back(result_type);
+      }
+    }
+    // Change `body_func`'s argument type to `unranked_argument_types`. If it
+    // return types contain a `DT_VARIANT`, change it to the unranked type
+    // derived from the corresponding argument.
+    body_func.setType(FunctionType::get(unranked_argument_types,
+                                        updated_result_types, context));
+    // Change the argument type for the first block.
+    Block &body_first_bb = body_func.front();
+    for (int i = 0; i < body_first_bb.getNumArguments(); ++i) {
+      body_first_bb.getArgument(i)->setType(unranked_argument_types[i]);
+    }
+  }
+
+  for (int i = 0; i < while_op->getNumOperands(); ++i) {
+    auto operand = while_op->getOperand(i);
+    auto result = while_op->getResult(i);
+    if (getElementTypeOrSelf(result->getType()).isa<TF::VariantType>()) {
+      // If we notice the result type is a DT_VARIANT, we change the
+      // corresponding result type to unranked tensor type.
+      result->setType(
+          UnrankedTensorType::get(getElementTypeOrSelf(operand->getType())));
+    }
+  }
+  return success();
+}
+
+LogicalResult LowerStaticTensorListPass::RewriteFunction(
+    Function func, TensorListPatternRewriter *rewriter) {
+  auto *context = &getContext();
+
+  for (Block &block : func) {
+    // Buffer the op pointers inside the current block into a vector, since
+    // the block iterator might be invalidated if we rewrite ops during looping.
+    std::vector<Operation *> ops_in_block;
     for (Operation &op : block) {
+      ops_in_block.push_back(&op);
+    }
+
+    for (Operation *op : ops_in_block) {
       if (auto tf_op = llvm::dyn_cast<TF::TensorListFromTensorOp>(op)) {
-        tensor_type = tf_op.tensor()->getType();
+        auto c = TFL::ConvertTFTensorListFromTensor(context);
+        rewriter->setInsertionPoint(op);
+        c.matchAndRewrite(op, *rewriter);
       } else if (auto tf_op = llvm::dyn_cast<TF::TensorListReserveOp>(op)) {
         if (!(tf_op.element_dtype().isF16() || tf_op.element_dtype().isF32() ||
               tf_op.element_dtype().isF64() ||
-              tf_op.element_dtype().isa<IntegerType>())) {
+              tf_op.element_dtype().isInteger(8) ||
+              tf_op.element_dtype().isInteger(16) ||
+              tf_op.element_dtype().isInteger(32) ||
+              tf_op.element_dtype().isInteger(64))) {
           return tf_op.emitError(
-              "requires element_dtype to be integer or 16-bit/32-bit/64-bit "
+              "requires element_dtype to be 8-bit/16-bit/32-bit/64-bit integer "
+              "or 16-bit/32-bit/64-bit "
              "float type during TF Lite transformation pass");
        }
-        // TODO(haoliang): figure out better way of specify shape.
-        tensor_type = UnrankedTensorType::get(tf_op.element_dtype());
-      }
-      if (auto tf_op = llvm::dyn_cast<TF::TensorListSetItemOp>(op)) {
-        tf_op.input_handle()->setType(tensor_type);
-        tf_op.getResult()->setType(tensor_type);
-      }
-      // Currently we will raise an error if an op other than the following
-      // contains a DT_VARIANT tensor as its input or output. Below ops already
-      // have proper transformation patterns that eliminate the need of
-      // `DT_VARIANT`, we consider it's safe to not raise an error on those ops.
-      if (llvm::isa<TF::TensorListFromTensorOp>(op) ||
-          llvm::isa<TF::TensorListReserveOp>(op) ||
-          llvm::isa<TF::TensorListSetItemOp>(op) ||
-          llvm::isa<TF::TensorListStackOp>(op) ||
-          llvm::isa<TF::TensorListGetItemOp>(op)) {
-        continue;
-      }
-      // Check if any of the input operand is a DT_VARIANT.
-      for (Type type : op.getOperandTypes()) {
-        if (type.isa<TF::VariantType>()) {
-          return op.emitError(
-              "op's input contains a DT_VARIANT tensor. Currently we only "
-              "allow "
-              "TensorListFromTensor/TensorListReserve/TensorListStack/"
-              "TensorListSetItem/"
-              "TensorListGetItem to have DT_VARIANT input/output");
-        }
-      }
-      // Check if any of the output is a DT_VARIANT.
-      for (Type type : op.getResultTypes()) {
-        if (type.isa<TF::VariantType>()) {
-          return op.emitError(
-              "op's output contains a DT_VARIANT tensor. Currently we only "
-              "allow "
-              "TensorListFromTensor/TensorListReserve/TensorListStack/"
-              "TensorListSetItem/"
-              "TensorListGetItem to have DT_VARIANT input/output");
-        }
-      }
+        auto c = ConvertTFTensorListReserve(context);
+        rewriter->setInsertionPoint(op);
+        c.matchAndRewrite(op, *rewriter);
+      } else if (auto tf_op = llvm::dyn_cast<TF::TensorListGetItemOp>(op)) {
+        auto c = TFL::ConvertTFTensorListGetItem(context);
+        rewriter->setInsertionPoint(op);
+        c.matchAndRewrite(op, *rewriter);
+      } else if (auto tf_op = llvm::dyn_cast<TF::TensorListSetItemOp>(op)) {
+        auto c = ConvertTFTensorListSetItem(context);
+        rewriter->setInsertionPoint(op);
+        c.matchAndRewrite(op, *rewriter);
+      } else if (auto tf_op = llvm::dyn_cast<TF::TensorListStackOp>(op)) {
+        auto c = TFL::ConvertTFTensorListStack(context);
+        rewriter->setInsertionPoint(op);
+        c.matchAndRewrite(op, *rewriter);
+      } else if (auto tf_op = llvm::dyn_cast<TF::WhileOp>(op)) {
+        if (op->getAttr("T")) op->removeAttr(Identifier::get("T", context));
+        UpdateWhileFunctionType(&tf_op);
+      } else if (auto tf_op = llvm::dyn_cast<TF::IdentityOp>(op)) {
+        if (op->getAttr("T")) op->removeAttr(Identifier::get("T", context));
+        tf_op.getResult()->setType(tf_op.getOperand()->getType());
+      }
     }
   }
   return success();
 }
 
-void LowerStaticTensorListPass::runOnFunction() {
-  if (failed(ModifyTensorList())) {
-    signalPassFailure();
-    return;
+void LowerStaticTensorListPass::runOnModule() {
+  // TODO(haoliang): currently we process the `main` function first, and the
+  // remaining functions may be processed in arbitrary order. However, this will
+  // have a potential issue when one function taking a `DT_VARIANT` is processed
+  // before the function that produces the `DT_VARIANT`. We need to carefully
+  // order the functions to be processed.
+  std::vector<Function> funcs_in_module;
+  for (auto func : getModule().getFunctions()) {
+    // Always place the main function to be the first in the list.
+    if (func.getName().is("main")) {
+      funcs_in_module.insert(funcs_in_module.begin(), func);
+    } else {
+      funcs_in_module.push_back(func);
+    }
+  }
+  for (auto func : funcs_in_module) {
+    TensorListPatternRewriter rewriter(func);
+    if (failed(RewriteFunction(func, &rewriter))) {
+      signalPassFailure();
+      return;
+    }
   }
-  OwningRewritePatternList patterns;
-  auto func = getFunction();
-  TFL::populateWithGenerated(&getContext(), &patterns);
-  patterns.push_back(
-      llvm::make_unique<ConvertTFTensorListReserve>(&getContext()));
-  patterns.push_back(
-      llvm::make_unique<ConvertTFTensorListSetItem>(&getContext()));
-  applyPatternsGreedily(func, std::move(patterns));
 }
 
-// Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList
-// pass.
-FunctionPassBase *TFL::CreateLowerStaticTensorListPass() {
+/// Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList
+/// pass.
+ModulePassBase *TFL::CreateLowerStaticTensorListPass() {
   return new LowerStaticTensorListPass();
 }
 
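A note on the rewritten C++ above: making the pass a ModulePass lets `UpdateWhileFunctionType` reach the `cond`/`body` functions referenced by a `tf.While` and rewrite their signatures along with the op itself. The new test's CHECK lines describe the intended result roughly as follows (a sketch of the expected IR after lowering, not verbatim pass output; value names are illustrative):

// The `T` attribute is gone and every DT_VARIANT value has been replaced:
// the list operand becomes the backing tensor, and the variant result
// becomes an unranked tensor.
%loop:2 = "tf.While"(%start, %arg0)
    {body = @tensorlistWhileBody, cond = @tensorlistWhileCond}
    : (tensor<i32>, tensor<2x3xf32>) -> (tensor<i32>, tensor<*xf32>)
return %loop#1 : tensor<*xf32>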
@@ -18,6 +18,7 @@ limitations under the License.
 
 namespace mlir {
 class FunctionPassBase;
+class ModulePassBase;
 
 namespace TFL {
 
@@ -32,7 +33,7 @@ FunctionPassBase *CreatePrepareTFPass();
 
 // Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList
 // pass.
-FunctionPassBase *CreateLowerStaticTensorListPass();
+ModulePassBase *CreateLowerStaticTensorListPass();
 
 // Creates an instance of the TensorFlow Lite dialect Quantize pass.
 FunctionPassBase *CreateQuantizePass();
@@ -23,11 +23,14 @@ include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
 // into regular tensors. We also assume that each element in the `TensorList` has
 // a same constant shape.
 //===----------------------------------------------------------------------===//
-def : Pat<(TF_TensorListFromTensorOp $tensor, $element_shape),
-          (replaceWithValue $tensor)>;
+def ConvertTFTensorListFromTensor : Pat<
+  (TF_TensorListFromTensorOp $tensor, $element_shape),
+  (replaceWithValue $tensor)>;
 
-def : Pat<(TF_TensorListStackOp $input, $element_shape, $num_elements),
-          (replaceWithValue $input)>;
+def ConvertTFTensorListStack : Pat<
+  (TF_TensorListStackOp $input, $element_shape, $num_elements),
+  (replaceWithValue $input)>;
 
-def : Pat<(TF_TensorListGetItemOp $input, $index, $element_shape),
-          (TF_GatherOp $input, $index, (NativeCodeCall<"$_builder.getBoolAttr(true)">))>;
+def ConvertTFTensorListGetItem : Pat<
+  (TF_TensorListGetItemOp $input, $index, $element_shape),
+  (TF_GatherOp $input, $index, (NativeCodeCall<"$_builder.getBoolAttr(true)">))>;
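The TableGen patterns above only gain names so the C++ driver can apply them individually; their rewrites are unchanged. As an illustration, `ConvertTFTensorListGetItem` still turns an item lookup into a gather on the backing tensor (a before/after sketch with illustrative value names and types):

// Before: the read goes through the variant list handle.
%item = "tf.TensorListGetItem"(%list, %index, %element_shape)
    : (tensor<*x!tf.variant>, tensor<i32>, tensor<1xi32>) -> tensor<10xf32>

// After the pattern fires: the handle is replaced by the backing tensor and
// the read becomes a gather with validate_indices = true.
%item = "tf.Gather"(%tensor, %index) {validate_indices = true}
    : (tensor<3x10xf32>, tensor<i32>) -> tensor<10xf32>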