Lower tf.Const op with TensorList in tfl-lower-static-tensor-list pass
tf.Const op is lowered to one Const op for each of the element in the TensorList which are then combined using tf.Pack op. PiperOrigin-RevId: 292073651 Change-Id: I7eb0fb0864b03ccadd83bb50ae7a469fa21a0196
This commit is contained in:
parent
138d4a0a58
commit
941950947d
tensorflow/compiler/mlir/lite
@ -295,13 +295,16 @@ cc_library(
|
||||
":validators",
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:convert_tensor",
|
||||
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
|
||||
"//tensorflow/compiler/xla:status",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/kernels:tensor_list",
|
||||
"//tensorflow/core/platform:logging",
|
||||
"@com_google_absl//absl/container:inlined_vector",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
|
@ -1,5 +1,26 @@
|
||||
// RUN: tf-opt -tfl-lower-static-tensor-list %s | FileCheck %s --dump-input-on-failure
|
||||
|
||||
// CHECK-LABEL: tensorlistConst
|
||||
func @tensorlistConst(%arg0 : tensor<1xi32>) -> tensor<2x3xi32> {
|
||||
// CHECK: %[[ELEMENT0:.*]] = "tf.Const"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
|
||||
// CHECK: %[[ELEMENT1:.*]] = "tf.Const"() {value = dense<[3, 4, 5]> : tensor<3xi32>} : () -> tensor<3xi32>
|
||||
// CHECK: %[[LIST:.*]] = "tf.Pack"(%[[ELEMENT0]], %[[ELEMENT1]]) {axis = 0 : i64} : (tensor<3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
|
||||
%0 = "tf.Const"() {value = opaque<"tf", "0x746674656E736F722464747970653A2044545F56415249414E542074656E736F725F7368617065207B207D2074656E736F725F636F6E74656E743A2022485C6E5C30323674656E736F72666C6F773A3A54656E736F724C6973745C3032325C3032305C3030305C3030335C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3030315C3032325C3030325C3031305C3030335C3033325C725C3031305C3030335C3032325C3030345C3032325C3030325C3031305C3030333A5C3030335C3030305C3030315C3030325C3033325C725C3031305C3030335C3032325C3030345C3032325C3030325C3031305C3030333A5C3030335C3030335C3030345C30303522"> : tensor<!tf.variant>} : () -> tensor<!tf.variant<tensor<3xi32>>>
|
||||
|
||||
// CHECK: return %[[LIST]]
|
||||
%1 = "tf.TensorListStack"(%0, %arg0) : (tensor<!tf.variant<tensor<3xi32>>>, tensor<1xi32>) -> tensor<2x3xi32>
|
||||
return %1 : tensor<2x3xi32>
|
||||
}
|
||||
|
||||
func @emptyTensorlistConst(%arg0 : tensor<1xi32>) -> tensor<0x3xi32> {
|
||||
// CHECK: %[[LIST:.*]] = "tf.Const"() {value = dense<{{\[\[}}]]> : tensor<0x3xi32>} : () -> tensor<0x3xi32>
|
||||
%0 = "tf.Const"() {value = opaque<"tf", "0x746674656E736F722464747970653A2044545F56415249414E542074656E736F725F7368617065207B207D2074656E736F725F636F6E74656E743A20222A5C6E5C30323674656E736F72666C6F773A3A54656E736F724C6973745C3032325C3032305C3030305C3030335C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3030315C3032325C3030325C3031305C30303322"> : tensor<!tf.variant>} : () -> tensor<!tf.variant<tensor<3xi32>>>
|
||||
|
||||
// CHECK: return %[[LIST]]
|
||||
%1 = "tf.TensorListStack"(%0, %arg0) : (tensor<!tf.variant<tensor<3xi32>>>, tensor<1xi32>) -> tensor<0x3xi32>
|
||||
return %1 : tensor<0x3xi32>
|
||||
}
|
||||
|
||||
func @tensorlistGetItem(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor<i32>) -> (tensor<10xf32>, tensor<3x10xf32>) {
|
||||
%0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<3x10xf32>, tensor<1xi32>) -> tensor<!tf.variant<tensor<10xf32>>>
|
||||
%1 = "tf.TensorListGetItem"(%0, %arg2, %arg1) : (tensor<!tf.variant<tensor<10xf32>>>, tensor<i32>, tensor<1xi32>) -> tensor<10xf32>
|
||||
|
@ -80,10 +80,6 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
||||
}
|
||||
|
||||
if (pass_config.lower_tensor_list_ops) {
|
||||
// Execute this pass before `CanonicalizerPass` in case some TensorList
|
||||
// ops are constant folded into variant types.
|
||||
// TODO(b/137125056): Move this pass after `CanonicalizerPass` after we
|
||||
// handle constant ops that produce `TensorList`.
|
||||
// TODO(haoliang): Add this pass by default.
|
||||
pass_manager->addPass(mlir::TFL::CreateLowerStaticTensorListPass());
|
||||
}
|
||||
|
@ -23,9 +23,11 @@ limitations under the License.
|
||||
#include <climits>
|
||||
#include <cstdint>
|
||||
|
||||
#include "absl/container/inlined_vector.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/None.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
@ -57,6 +59,10 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/kernels/tensor_list.h"
|
||||
|
||||
#define DEBUG_TYPE "tf-tfl-legalization"
|
||||
|
||||
@ -162,6 +168,86 @@ TF::SliceOp CreateSliceOpForTensorList(Location loc, Value input_list,
|
||||
start_position, slice_size);
|
||||
}
|
||||
|
||||
// Converts tf.Const containing variant of type TensorList to a tensor of
|
||||
// primitive element types. Each of the individual tensor in the list is
|
||||
// converted to an ElementsAttr and then those are packed together using
|
||||
// tf.Pack op.
|
||||
struct ConvertConst : public OpConversionPattern<TF::ConstOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(
|
||||
TF::ConstOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// Verify that the opaque elements attribute contains tensor of type variant
|
||||
// and scalar shape. The variant type should hold a TensorList.
|
||||
auto opaque_attr = op.value().dyn_cast<OpaqueElementsAttr>();
|
||||
if (!opaque_attr) return matchFailure();
|
||||
tensorflow::Tensor tensor;
|
||||
if (!tensorflow::ConvertToTensor(opaque_attr, &tensor).ok())
|
||||
return matchFailure();
|
||||
if (tensor.dtype() != tensorflow::DT_VARIANT) return matchFailure();
|
||||
if (!tensorflow::TensorShapeUtils::IsScalar(tensor.shape()))
|
||||
return matchFailure();
|
||||
|
||||
const tensorflow::TensorList *list =
|
||||
tensor.scalar<tensorflow::Variant>()().get<tensorflow::TensorList>();
|
||||
if (!list) return matchFailure();
|
||||
|
||||
// Verify output type is variant and contains exactly one ranked subtypes.
|
||||
auto variant_ty =
|
||||
getElementTypeOrSelf(op.getType()).dyn_cast<TF::VariantType>();
|
||||
if (!variant_ty) return matchFailure();
|
||||
ArrayRef<TensorType> subtypes = variant_ty.getSubtypes();
|
||||
if (subtypes.size() != 1) return matchFailure();
|
||||
RankedTensorType list_element_ty =
|
||||
subtypes.front().dyn_cast<RankedTensorType>();
|
||||
if (!list_element_ty) return matchFailure();
|
||||
|
||||
// Extract tensor elements for the TensorList and construct result type
|
||||
// based on the number of elements and element shape.
|
||||
const std::vector<tensorflow::Tensor> &tensors = list->tensors();
|
||||
llvm::SmallVector<int64_t, 4> result_shape = {
|
||||
static_cast<int64_t>(tensors.size())};
|
||||
result_shape.append(list_element_ty.getShape().begin(),
|
||||
list_element_ty.getShape().end());
|
||||
auto result_ty =
|
||||
RankedTensorType::get(result_shape, list_element_ty.getElementType());
|
||||
|
||||
// If the list is empty, directly create the final result instead of
|
||||
// creating the tf.Pack op. tf.Pack op requires at least one operand.
|
||||
if (tensors.empty()) {
|
||||
absl::InlinedVector<tensorflow::int64, 4> tf_shape;
|
||||
tf_shape.reserve(result_shape.size());
|
||||
for (int64_t dim : result_shape) {
|
||||
tf_shape.push_back(dim);
|
||||
}
|
||||
|
||||
tensorflow::Tensor tensor(list->element_dtype,
|
||||
tensorflow::TensorShape(tf_shape));
|
||||
auto attr_or = tensorflow::ConvertTensor(tensor, &rewriter);
|
||||
if (!attr_or.ok()) return matchFailure();
|
||||
rewriter.replaceOpWithNewOp<TF::ConstOp>(op, attr_or.ValueOrDie());
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
// Extract individual tensor list element and combine them using the tf.Pack
|
||||
// op.
|
||||
Location loc = op.getLoc();
|
||||
llvm::SmallVector<Value, 4> values;
|
||||
values.reserve(tensors.size());
|
||||
for (const tensorflow::Tensor &tensor : tensors) {
|
||||
auto attr_or = tensorflow::ConvertTensor(tensor, &rewriter);
|
||||
if (!attr_or.ok()) return matchFailure();
|
||||
|
||||
auto value = rewriter.create<TF::ConstOp>(loc, attr_or.ValueOrDie());
|
||||
values.push_back(value);
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<TF::PackOp>(
|
||||
op, result_ty, values, /*axis=*/rewriter.getI64IntegerAttr(0));
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
struct ConvertTensorListSetItem
|
||||
: public OpConversionPattern<TF::TensorListSetItemOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
@ -768,7 +854,7 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction(
|
||||
|
||||
OwningRewritePatternList patterns;
|
||||
patterns
|
||||
.insert<ConvertEmptyTensorList, ConvertIdentity,
|
||||
.insert<ConvertConst, ConvertEmptyTensorList, ConvertIdentity,
|
||||
ConvertTensorListFromTensor, ConvertTensorListGetItem,
|
||||
ConvertTensorListLength, ConvertTensorListPushBack,
|
||||
ConvertTensorListReserve, ConvertTensorListSetItem,
|
||||
|
Loading…
Reference in New Issue
Block a user