Lower tf.Const op with TensorList in tfl-lower-static-tensor-list pass
tf.Const op is lowered to one Const op for each of the element in the TensorList which are then combined using tf.Pack op. PiperOrigin-RevId: 292073651 Change-Id: I7eb0fb0864b03ccadd83bb50ae7a469fa21a0196
This commit is contained in:
parent
138d4a0a58
commit
941950947d
@ -295,13 +295,16 @@ cc_library(
|
|||||||
":validators",
|
":validators",
|
||||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
|
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
|
||||||
"//tensorflow/compiler/mlir/tensorflow",
|
"//tensorflow/compiler/mlir/tensorflow",
|
||||||
|
"//tensorflow/compiler/mlir/tensorflow:convert_tensor",
|
||||||
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
|
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
|
||||||
"//tensorflow/compiler/xla:status",
|
"//tensorflow/compiler/xla:status",
|
||||||
"//tensorflow/compiler/xla:statusor",
|
"//tensorflow/compiler/xla:statusor",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"//tensorflow/core/kernels:tensor_list",
|
||||||
"//tensorflow/core/platform:logging",
|
"//tensorflow/core/platform:logging",
|
||||||
|
"@com_google_absl//absl/container:inlined_vector",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"@llvm-project//llvm:support",
|
"@llvm-project//llvm:support",
|
||||||
"@llvm-project//mlir:Analysis",
|
"@llvm-project//mlir:Analysis",
|
||||||
|
@ -1,5 +1,26 @@
|
|||||||
// RUN: tf-opt -tfl-lower-static-tensor-list %s | FileCheck %s --dump-input-on-failure
|
// RUN: tf-opt -tfl-lower-static-tensor-list %s | FileCheck %s --dump-input-on-failure
|
||||||
|
|
||||||
|
// CHECK-LABEL: tensorlistConst
|
||||||
|
func @tensorlistConst(%arg0 : tensor<1xi32>) -> tensor<2x3xi32> {
|
||||||
|
// CHECK: %[[ELEMENT0:.*]] = "tf.Const"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
|
||||||
|
// CHECK: %[[ELEMENT1:.*]] = "tf.Const"() {value = dense<[3, 4, 5]> : tensor<3xi32>} : () -> tensor<3xi32>
|
||||||
|
// CHECK: %[[LIST:.*]] = "tf.Pack"(%[[ELEMENT0]], %[[ELEMENT1]]) {axis = 0 : i64} : (tensor<3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
|
||||||
|
%0 = "tf.Const"() {value = opaque<"tf", "0x746674656E736F722464747970653A2044545F56415249414E542074656E736F725F7368617065207B207D2074656E736F725F636F6E74656E743A2022485C6E5C30323674656E736F72666C6F773A3A54656E736F724C6973745C3032325C3032305C3030305C3030335C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3030315C3032325C3030325C3031305C3030335C3033325C725C3031305C3030335C3032325C3030345C3032325C3030325C3031305C3030333A5C3030335C3030305C3030315C3030325C3033325C725C3031305C3030335C3032325C3030345C3032325C3030325C3031305C3030333A5C3030335C3030335C3030345C30303522"> : tensor<!tf.variant>} : () -> tensor<!tf.variant<tensor<3xi32>>>
|
||||||
|
|
||||||
|
// CHECK: return %[[LIST]]
|
||||||
|
%1 = "tf.TensorListStack"(%0, %arg0) : (tensor<!tf.variant<tensor<3xi32>>>, tensor<1xi32>) -> tensor<2x3xi32>
|
||||||
|
return %1 : tensor<2x3xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
func @emptyTensorlistConst(%arg0 : tensor<1xi32>) -> tensor<0x3xi32> {
|
||||||
|
// CHECK: %[[LIST:.*]] = "tf.Const"() {value = dense<{{\[\[}}]]> : tensor<0x3xi32>} : () -> tensor<0x3xi32>
|
||||||
|
%0 = "tf.Const"() {value = opaque<"tf", "0x746674656E736F722464747970653A2044545F56415249414E542074656E736F725F7368617065207B207D2074656E736F725F636F6E74656E743A20222A5C6E5C30323674656E736F72666C6F773A3A54656E736F724C6973745C3032325C3032305C3030305C3030335C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3030315C3032325C3030325C3031305C30303322"> : tensor<!tf.variant>} : () -> tensor<!tf.variant<tensor<3xi32>>>
|
||||||
|
|
||||||
|
// CHECK: return %[[LIST]]
|
||||||
|
%1 = "tf.TensorListStack"(%0, %arg0) : (tensor<!tf.variant<tensor<3xi32>>>, tensor<1xi32>) -> tensor<0x3xi32>
|
||||||
|
return %1 : tensor<0x3xi32>
|
||||||
|
}
|
||||||
|
|
||||||
func @tensorlistGetItem(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor<i32>) -> (tensor<10xf32>, tensor<3x10xf32>) {
|
func @tensorlistGetItem(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor<i32>) -> (tensor<10xf32>, tensor<3x10xf32>) {
|
||||||
%0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<3x10xf32>, tensor<1xi32>) -> tensor<!tf.variant<tensor<10xf32>>>
|
%0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<3x10xf32>, tensor<1xi32>) -> tensor<!tf.variant<tensor<10xf32>>>
|
||||||
%1 = "tf.TensorListGetItem"(%0, %arg2, %arg1) : (tensor<!tf.variant<tensor<10xf32>>>, tensor<i32>, tensor<1xi32>) -> tensor<10xf32>
|
%1 = "tf.TensorListGetItem"(%0, %arg2, %arg1) : (tensor<!tf.variant<tensor<10xf32>>>, tensor<i32>, tensor<1xi32>) -> tensor<10xf32>
|
||||||
|
@ -80,10 +80,6 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (pass_config.lower_tensor_list_ops) {
|
if (pass_config.lower_tensor_list_ops) {
|
||||||
// Execute this pass before `CanonicalizerPass` in case some TensorList
|
|
||||||
// ops are constant folded into variant types.
|
|
||||||
// TODO(b/137125056): Move this pass after `CanonicalizerPass` after we
|
|
||||||
// handle constant ops that produce `TensorList`.
|
|
||||||
// TODO(haoliang): Add this pass by default.
|
// TODO(haoliang): Add this pass by default.
|
||||||
pass_manager->addPass(mlir::TFL::CreateLowerStaticTensorListPass());
|
pass_manager->addPass(mlir::TFL::CreateLowerStaticTensorListPass());
|
||||||
}
|
}
|
||||||
|
@ -23,9 +23,11 @@ limitations under the License.
|
|||||||
#include <climits>
|
#include <climits>
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
|
||||||
|
#include "absl/container/inlined_vector.h"
|
||||||
#include "llvm/ADT/ArrayRef.h"
|
#include "llvm/ADT/ArrayRef.h"
|
||||||
#include "llvm/ADT/None.h"
|
#include "llvm/ADT/None.h"
|
||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "llvm/ADT/StringSwitch.h"
|
#include "llvm/ADT/StringSwitch.h"
|
||||||
#include "llvm/Support/Casting.h"
|
#include "llvm/Support/Casting.h"
|
||||||
#include "llvm/Support/Debug.h"
|
#include "llvm/Support/Debug.h"
|
||||||
@ -57,6 +59,10 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
|
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/framework/types.pb.h"
|
||||||
|
#include "tensorflow/core/kernels/tensor_list.h"
|
||||||
|
|
||||||
#define DEBUG_TYPE "tf-tfl-legalization"
|
#define DEBUG_TYPE "tf-tfl-legalization"
|
||||||
|
|
||||||
@ -162,6 +168,86 @@ TF::SliceOp CreateSliceOpForTensorList(Location loc, Value input_list,
|
|||||||
start_position, slice_size);
|
start_position, slice_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Converts tf.Const containing variant of type TensorList to a tensor of
|
||||||
|
// primitive element types. Each of the individual tensor in the list is
|
||||||
|
// converted to an ElementsAttr and then those are packed together using
|
||||||
|
// tf.Pack op.
|
||||||
|
struct ConvertConst : public OpConversionPattern<TF::ConstOp> {
|
||||||
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
|
||||||
|
PatternMatchResult matchAndRewrite(
|
||||||
|
TF::ConstOp op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
// Verify that the opaque elements attribute contains tensor of type variant
|
||||||
|
// and scalar shape. The variant type should hold a TensorList.
|
||||||
|
auto opaque_attr = op.value().dyn_cast<OpaqueElementsAttr>();
|
||||||
|
if (!opaque_attr) return matchFailure();
|
||||||
|
tensorflow::Tensor tensor;
|
||||||
|
if (!tensorflow::ConvertToTensor(opaque_attr, &tensor).ok())
|
||||||
|
return matchFailure();
|
||||||
|
if (tensor.dtype() != tensorflow::DT_VARIANT) return matchFailure();
|
||||||
|
if (!tensorflow::TensorShapeUtils::IsScalar(tensor.shape()))
|
||||||
|
return matchFailure();
|
||||||
|
|
||||||
|
const tensorflow::TensorList *list =
|
||||||
|
tensor.scalar<tensorflow::Variant>()().get<tensorflow::TensorList>();
|
||||||
|
if (!list) return matchFailure();
|
||||||
|
|
||||||
|
// Verify output type is variant and contains exactly one ranked subtypes.
|
||||||
|
auto variant_ty =
|
||||||
|
getElementTypeOrSelf(op.getType()).dyn_cast<TF::VariantType>();
|
||||||
|
if (!variant_ty) return matchFailure();
|
||||||
|
ArrayRef<TensorType> subtypes = variant_ty.getSubtypes();
|
||||||
|
if (subtypes.size() != 1) return matchFailure();
|
||||||
|
RankedTensorType list_element_ty =
|
||||||
|
subtypes.front().dyn_cast<RankedTensorType>();
|
||||||
|
if (!list_element_ty) return matchFailure();
|
||||||
|
|
||||||
|
// Extract tensor elements for the TensorList and construct result type
|
||||||
|
// based on the number of elements and element shape.
|
||||||
|
const std::vector<tensorflow::Tensor> &tensors = list->tensors();
|
||||||
|
llvm::SmallVector<int64_t, 4> result_shape = {
|
||||||
|
static_cast<int64_t>(tensors.size())};
|
||||||
|
result_shape.append(list_element_ty.getShape().begin(),
|
||||||
|
list_element_ty.getShape().end());
|
||||||
|
auto result_ty =
|
||||||
|
RankedTensorType::get(result_shape, list_element_ty.getElementType());
|
||||||
|
|
||||||
|
// If the list is empty, directly create the final result instead of
|
||||||
|
// creating the tf.Pack op. tf.Pack op requires at least one operand.
|
||||||
|
if (tensors.empty()) {
|
||||||
|
absl::InlinedVector<tensorflow::int64, 4> tf_shape;
|
||||||
|
tf_shape.reserve(result_shape.size());
|
||||||
|
for (int64_t dim : result_shape) {
|
||||||
|
tf_shape.push_back(dim);
|
||||||
|
}
|
||||||
|
|
||||||
|
tensorflow::Tensor tensor(list->element_dtype,
|
||||||
|
tensorflow::TensorShape(tf_shape));
|
||||||
|
auto attr_or = tensorflow::ConvertTensor(tensor, &rewriter);
|
||||||
|
if (!attr_or.ok()) return matchFailure();
|
||||||
|
rewriter.replaceOpWithNewOp<TF::ConstOp>(op, attr_or.ValueOrDie());
|
||||||
|
return matchSuccess();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract individual tensor list element and combine them using the tf.Pack
|
||||||
|
// op.
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
llvm::SmallVector<Value, 4> values;
|
||||||
|
values.reserve(tensors.size());
|
||||||
|
for (const tensorflow::Tensor &tensor : tensors) {
|
||||||
|
auto attr_or = tensorflow::ConvertTensor(tensor, &rewriter);
|
||||||
|
if (!attr_or.ok()) return matchFailure();
|
||||||
|
|
||||||
|
auto value = rewriter.create<TF::ConstOp>(loc, attr_or.ValueOrDie());
|
||||||
|
values.push_back(value);
|
||||||
|
}
|
||||||
|
rewriter.replaceOpWithNewOp<TF::PackOp>(
|
||||||
|
op, result_ty, values, /*axis=*/rewriter.getI64IntegerAttr(0));
|
||||||
|
return matchSuccess();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
struct ConvertTensorListSetItem
|
struct ConvertTensorListSetItem
|
||||||
: public OpConversionPattern<TF::TensorListSetItemOp> {
|
: public OpConversionPattern<TF::TensorListSetItemOp> {
|
||||||
using OpConversionPattern::OpConversionPattern;
|
using OpConversionPattern::OpConversionPattern;
|
||||||
@ -768,7 +854,7 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction(
|
|||||||
|
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
patterns
|
patterns
|
||||||
.insert<ConvertEmptyTensorList, ConvertIdentity,
|
.insert<ConvertConst, ConvertEmptyTensorList, ConvertIdentity,
|
||||||
ConvertTensorListFromTensor, ConvertTensorListGetItem,
|
ConvertTensorListFromTensor, ConvertTensorListGetItem,
|
||||||
ConvertTensorListLength, ConvertTensorListPushBack,
|
ConvertTensorListLength, ConvertTensorListPushBack,
|
||||||
ConvertTensorListReserve, ConvertTensorListSetItem,
|
ConvertTensorListReserve, ConvertTensorListSetItem,
|
||||||
|
Loading…
Reference in New Issue
Block a user