Lower tf.Const op with TensorList in tfl-lower-static-tensor-list pass

tf.Const op is lowered to one Const op for each of the element in the TensorList which are then combined using tf.Pack op.

PiperOrigin-RevId: 292073651
Change-Id: I7eb0fb0864b03ccadd83bb50ae7a469fa21a0196
This commit is contained in:
Smit Hinsu 2020-01-28 21:29:09 -08:00 committed by TensorFlower Gardener
parent 138d4a0a58
commit 941950947d
4 changed files with 111 additions and 5 deletions

View File

@ -295,13 +295,16 @@ cc_library(
":validators", ":validators",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib", "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:convert_tensor",
"//tensorflow/compiler/mlir/tensorflow:mangling_util", "//tensorflow/compiler/mlir/tensorflow:mangling_util",
"//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:statusor",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:tensor_list",
"//tensorflow/core/platform:logging", "//tensorflow/core/platform:logging",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@llvm-project//llvm:support", "@llvm-project//llvm:support",
"@llvm-project//mlir:Analysis", "@llvm-project//mlir:Analysis",

View File

@ -1,5 +1,26 @@
// RUN: tf-opt -tfl-lower-static-tensor-list %s | FileCheck %s --dump-input-on-failure // RUN: tf-opt -tfl-lower-static-tensor-list %s | FileCheck %s --dump-input-on-failure
// CHECK-LABEL: tensorlistConst
func @tensorlistConst(%arg0 : tensor<1xi32>) -> tensor<2x3xi32> {
// CHECK: %[[ELEMENT0:.*]] = "tf.Const"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
// CHECK: %[[ELEMENT1:.*]] = "tf.Const"() {value = dense<[3, 4, 5]> : tensor<3xi32>} : () -> tensor<3xi32>
// CHECK: %[[LIST:.*]] = "tf.Pack"(%[[ELEMENT0]], %[[ELEMENT1]]) {axis = 0 : i64} : (tensor<3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
%0 = "tf.Const"() {value = opaque<"tf", "0x746674656E736F722464747970653A2044545F56415249414E542074656E736F725F7368617065207B207D2074656E736F725F636F6E74656E743A2022485C6E5C30323674656E736F72666C6F773A3A54656E736F724C6973745C3032325C3032305C3030305C3030335C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3030315C3032325C3030325C3031305C3030335C3033325C725C3031305C3030335C3032325C3030345C3032325C3030325C3031305C3030333A5C3030335C3030305C3030315C3030325C3033325C725C3031305C3030335C3032325C3030345C3032325C3030325C3031305C3030333A5C3030335C3030335C3030345C30303522"> : tensor<!tf.variant>} : () -> tensor<!tf.variant<tensor<3xi32>>>
// CHECK: return %[[LIST]]
%1 = "tf.TensorListStack"(%0, %arg0) : (tensor<!tf.variant<tensor<3xi32>>>, tensor<1xi32>) -> tensor<2x3xi32>
return %1 : tensor<2x3xi32>
}
func @emptyTensorlistConst(%arg0 : tensor<1xi32>) -> tensor<0x3xi32> {
// CHECK: %[[LIST:.*]] = "tf.Const"() {value = dense<{{\[\[}}]]> : tensor<0x3xi32>} : () -> tensor<0x3xi32>
%0 = "tf.Const"() {value = opaque<"tf", "0x746674656E736F722464747970653A2044545F56415249414E542074656E736F725F7368617065207B207D2074656E736F725F636F6E74656E743A20222A5C6E5C30323674656E736F72666C6F773A3A54656E736F724C6973745C3032325C3032305C3030305C3030335C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3030315C3032325C3030325C3031305C30303322"> : tensor<!tf.variant>} : () -> tensor<!tf.variant<tensor<3xi32>>>
// CHECK: return %[[LIST]]
%1 = "tf.TensorListStack"(%0, %arg0) : (tensor<!tf.variant<tensor<3xi32>>>, tensor<1xi32>) -> tensor<0x3xi32>
return %1 : tensor<0x3xi32>
}
func @tensorlistGetItem(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor<i32>) -> (tensor<10xf32>, tensor<3x10xf32>) { func @tensorlistGetItem(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor<i32>) -> (tensor<10xf32>, tensor<3x10xf32>) {
%0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<3x10xf32>, tensor<1xi32>) -> tensor<!tf.variant<tensor<10xf32>>> %0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<3x10xf32>, tensor<1xi32>) -> tensor<!tf.variant<tensor<10xf32>>>
%1 = "tf.TensorListGetItem"(%0, %arg2, %arg1) : (tensor<!tf.variant<tensor<10xf32>>>, tensor<i32>, tensor<1xi32>) -> tensor<10xf32> %1 = "tf.TensorListGetItem"(%0, %arg2, %arg1) : (tensor<!tf.variant<tensor<10xf32>>>, tensor<i32>, tensor<1xi32>) -> tensor<10xf32>

View File

@ -80,10 +80,6 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
} }
if (pass_config.lower_tensor_list_ops) { if (pass_config.lower_tensor_list_ops) {
// Execute this pass before `CanonicalizerPass` in case some TensorList
// ops are constant folded into variant types.
// TODO(b/137125056): Move this pass after `CanonicalizerPass` after we
// handle constant ops that produce `TensorList`.
// TODO(haoliang): Add this pass by default. // TODO(haoliang): Add this pass by default.
pass_manager->addPass(mlir::TFL::CreateLowerStaticTensorListPass()); pass_manager->addPass(mlir::TFL::CreateLowerStaticTensorListPass());
} }

View File

@ -23,9 +23,11 @@ limitations under the License.
#include <climits> #include <climits>
#include <cstdint> #include <cstdint>
#include "absl/container/inlined_vector.h"
#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h" #include "llvm/ADT/None.h"
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSwitch.h" #include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h" #include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h" #include "llvm/Support/Debug.h"
@ -57,6 +59,10 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/utils/validators.h" #include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/tensor_list.h"
#define DEBUG_TYPE "tf-tfl-legalization" #define DEBUG_TYPE "tf-tfl-legalization"
@ -162,6 +168,86 @@ TF::SliceOp CreateSliceOpForTensorList(Location loc, Value input_list,
start_position, slice_size); start_position, slice_size);
} }
// Converts tf.Const containing variant of type TensorList to a tensor of
// primitive element types. Each of the individual tensor in the list is
// converted to an ElementsAttr and then those are packed together using
// tf.Pack op.
struct ConvertConst : public OpConversionPattern<TF::ConstOp> {
using OpConversionPattern::OpConversionPattern;
PatternMatchResult matchAndRewrite(
TF::ConstOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
// Verify that the opaque elements attribute contains tensor of type variant
// and scalar shape. The variant type should hold a TensorList.
auto opaque_attr = op.value().dyn_cast<OpaqueElementsAttr>();
if (!opaque_attr) return matchFailure();
tensorflow::Tensor tensor;
if (!tensorflow::ConvertToTensor(opaque_attr, &tensor).ok())
return matchFailure();
if (tensor.dtype() != tensorflow::DT_VARIANT) return matchFailure();
if (!tensorflow::TensorShapeUtils::IsScalar(tensor.shape()))
return matchFailure();
const tensorflow::TensorList *list =
tensor.scalar<tensorflow::Variant>()().get<tensorflow::TensorList>();
if (!list) return matchFailure();
// Verify output type is variant and contains exactly one ranked subtypes.
auto variant_ty =
getElementTypeOrSelf(op.getType()).dyn_cast<TF::VariantType>();
if (!variant_ty) return matchFailure();
ArrayRef<TensorType> subtypes = variant_ty.getSubtypes();
if (subtypes.size() != 1) return matchFailure();
RankedTensorType list_element_ty =
subtypes.front().dyn_cast<RankedTensorType>();
if (!list_element_ty) return matchFailure();
// Extract tensor elements for the TensorList and construct result type
// based on the number of elements and element shape.
const std::vector<tensorflow::Tensor> &tensors = list->tensors();
llvm::SmallVector<int64_t, 4> result_shape = {
static_cast<int64_t>(tensors.size())};
result_shape.append(list_element_ty.getShape().begin(),
list_element_ty.getShape().end());
auto result_ty =
RankedTensorType::get(result_shape, list_element_ty.getElementType());
// If the list is empty, directly create the final result instead of
// creating the tf.Pack op. tf.Pack op requires at least one operand.
if (tensors.empty()) {
absl::InlinedVector<tensorflow::int64, 4> tf_shape;
tf_shape.reserve(result_shape.size());
for (int64_t dim : result_shape) {
tf_shape.push_back(dim);
}
tensorflow::Tensor tensor(list->element_dtype,
tensorflow::TensorShape(tf_shape));
auto attr_or = tensorflow::ConvertTensor(tensor, &rewriter);
if (!attr_or.ok()) return matchFailure();
rewriter.replaceOpWithNewOp<TF::ConstOp>(op, attr_or.ValueOrDie());
return matchSuccess();
}
// Extract individual tensor list element and combine them using the tf.Pack
// op.
Location loc = op.getLoc();
llvm::SmallVector<Value, 4> values;
values.reserve(tensors.size());
for (const tensorflow::Tensor &tensor : tensors) {
auto attr_or = tensorflow::ConvertTensor(tensor, &rewriter);
if (!attr_or.ok()) return matchFailure();
auto value = rewriter.create<TF::ConstOp>(loc, attr_or.ValueOrDie());
values.push_back(value);
}
rewriter.replaceOpWithNewOp<TF::PackOp>(
op, result_ty, values, /*axis=*/rewriter.getI64IntegerAttr(0));
return matchSuccess();
}
};
struct ConvertTensorListSetItem struct ConvertTensorListSetItem
: public OpConversionPattern<TF::TensorListSetItemOp> { : public OpConversionPattern<TF::TensorListSetItemOp> {
using OpConversionPattern::OpConversionPattern; using OpConversionPattern::OpConversionPattern;
@ -768,7 +854,7 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction(
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
patterns patterns
.insert<ConvertEmptyTensorList, ConvertIdentity, .insert<ConvertConst, ConvertEmptyTensorList, ConvertIdentity,
ConvertTensorListFromTensor, ConvertTensorListGetItem, ConvertTensorListFromTensor, ConvertTensorListGetItem,
ConvertTensorListLength, ConvertTensorListPushBack, ConvertTensorListLength, ConvertTensorListPushBack,
ConvertTensorListReserve, ConvertTensorListSetItem, ConvertTensorListReserve, ConvertTensorListSetItem,