Convert EmptyTensorList and TensorListPushBack ops in tfl_lower_static_tensor_list pass

* Define EmptyTensorList op by sharing the common members with TensorListReserve op because all members except the summary, description and arguments are common.
* Introduce a new template RewritePattern to share logic between EmptyTensorList and TensorListReserve op. These two only differs in the number of elements present in the initialized tensor list.
* Convert TensorListPushBack by expanding the item to match rank of the input and then concatenating them.

PiperOrigin-RevId: 261054481
This commit is contained in:
Smit Hinsu 2019-07-31 21:03:25 -07:00 committed by TensorFlower Gardener
parent 1e91bd2715
commit 9ac55fbefc
4 changed files with 230 additions and 68 deletions

View File

@ -92,13 +92,14 @@ func @tensorlistReserve(%arg0: tensor<3xi32>, %arg1: tensor<i32>, %arg2: tensor<
return %1 : tensor<?x?x?xf32>
// CHECK-LABEL: tensorlistReserve
// CHECK: %cst = constant dense<0> : tensor<i32>
// CHECK: %0 = "tf.ExpandDims"(%arg1, %cst) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
// CHECK: %1 = "tf.Concat"(%cst, %0, %arg0) {N = 2 : i64} : (tensor<i32>, tensor<1xi32>, tensor<3xi32>) -> tensor<4xi32>
// CHECK: %cst_0 = constant dense<0.000000e+00> : tensor<f32>
// CHECK: %2 = "tf.Fill"(%1, %cst_0) : (tensor<4xi32>, tensor<f32>) -> tensor<?x?x?x?xf32>
// CHECK: %3 = "tf.Gather"(%2, %arg2) {validate_indices = true} : (tensor<?x?x?x?xf32>, tensor<i32>) -> tensor<?x?x?xf32>
// CHECK: return %3 : tensor<?x?x?xf32>
// CHECK-DAG: [[ZERO1:%cst.*]] = constant dense<0> : tensor<i32>
// CHECK-DAG: [[ZERO2:%cst.*]] = constant dense<0> : tensor<i32>
// CHECK-DAG: [[DIM0:%.*]] = "tf.ExpandDims"(%arg1, [[ZERO1]]) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
// CHECK-DAG: [[SHAPE:%.*]] = "tf.Concat"([[ZERO2]], [[DIM0]], %arg0) {N = 2 : i64} : (tensor<i32>, tensor<1xi32>, tensor<3xi32>) -> tensor<4xi32>
// CHECK-DAG: [[VALUES:%.*]] = constant dense<0.000000e+00> : tensor<f32>
// CHECK: [[LIST:%.*]] = "tf.Fill"([[SHAPE]], [[VALUES]]) : (tensor<4xi32>, tensor<f32>) -> tensor<?x?x?x?xf32>
// CHECK: [[RESULT:%.*]] = "tf.Gather"([[LIST]], %arg2) {validate_indices = true} : (tensor<?x?x?x?xf32>, tensor<i32>) -> tensor<?x?x?xf32>
// CHECK: return [[RESULT]] : tensor<?x?x?xf32>
}
func @tensorlistReserveUnrankedElements(%arg0: tensor<?xi32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<*xf32> {
@ -112,6 +113,36 @@ func @tensorlistReserveUnrankedElements(%arg0: tensor<?xi32>, %arg1: tensor<i32>
// CHECK: return [[RESULT2]] : tensor<*xf32>
}
func @EmptyTensorList(%arg0: tensor<3xi32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<?x?x?xf32> {
%0 = "tf.EmptyTensorList"(%arg0, %arg1) : (tensor<3xi32>, tensor<i32>) -> tensor<!tf.variant<tensor<?x?x?xf32>>>
%1 = "tf.TensorListGetItem"(%0, %arg2, %arg0) : (tensor<!tf.variant<tensor<?x?x?xf32>>>, tensor<i32>, tensor<3xi32>) -> tensor<?x?x?xf32>
return %1 : tensor<?x?x?xf32>
// CHECK-LABEL: EmptyTensorList
// CHECK-SAME: ([[ELEM_SHAPE:%.*]]: tensor<3xi32>, [[MAX_ELEMS:%.*]]: tensor<i32>, [[IDX:%.*]]: tensor<i32>)
// CHECK-DAG: [[DIM0:%cst.*]] = constant dense<0> : tensor<1xi32>
// CHECK-DAG: [[ZERO:%cst.*]] = constant dense<0> : tensor<i32>
// CHECK-DAG: [[SHAPE:%.*]] = "tf.Concat"([[ZERO]], [[DIM0]], [[ELEM_SHAPE]]) {N = 2 : i64} : (tensor<i32>, tensor<1xi32>, tensor<3xi32>) -> tensor<4xi32>
// CHECK-DAG: [[VALUES:%.*]] = constant dense<0.000000e+00> : tensor<f32>
// CHECK: [[LIST:%.*]] = "tf.Fill"([[SHAPE]], [[VALUES]]) : (tensor<4xi32>, tensor<f32>) -> tensor<?x?x?x?xf32>
// CHECK: [[RESULT:%.*]] = "tf.Gather"([[LIST]], [[IDX]]) {validate_indices = true} : (tensor<?x?x?x?xf32>, tensor<i32>) -> tensor<?x?x?xf32>
// CHECK: return [[RESULT]] : tensor<?x?x?xf32>
}
func @tensorlistPushBack(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor<10xf32>) -> tensor<?x10xf32> {
%0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<3x10xf32>, tensor<1xi32>) -> tensor<!tf.variant<tensor<10xf32>>>
%1 = "tf.TensorListPushBack"(%0, %arg2) : (tensor<!tf.variant<tensor<10xf32>>>, tensor<10xf32>) -> tensor<!tf.variant<tensor<10xf32>>>
%2 = "tf.TensorListStack"(%1, %arg1) : (tensor<!tf.variant<tensor<10xf32>>>, tensor<1xi32>) -> tensor<?x10xf32>
return %2 : tensor<?x10xf32>
// CHECK-LABEL: tensorlistPushBack
// CHECK-SAME: ([[INPUT:%.*]]: tensor<3x10xf32>, [[ELEM_SHAPE:%.*]]: tensor<1xi32>, [[ITEM:%.*]]: tensor<10xf32>)
// CHECK: [[ZERO:%.*]] = constant dense<0> : tensor<i32>
// CHECK: [[EXP_ITEM:%.*]] = "tf.ExpandDims"([[ITEM]], [[ZERO]]) {{.*}} -> tensor<1x10xf32>
// CHECK: [[RESULT:%.*]] = "tf.Concat"(%cst, [[INPUT]], [[EXP_ITEM]]) {N = 2 : i64} : {{.*}} -> tensor<?x10xf32>
// CHECK: return [[RESULT]] : tensor<?x10xf32>
}
func @tensorlistWhileLoop(%arg0: tensor<2x3xf32>) -> tensor<*xf32> {
%cst = constant dense<3> : tensor<1xi32>
%cst_0 = constant dense<0> : tensor<i32>

View File

@ -196,48 +196,52 @@ struct ConvertTFTensorListSetItem
}
};
struct ConvertTFTensorListReserve
: public OpRewritePattern<TF::TensorListReserveOp> {
explicit ConvertTFTensorListReserve(MLIRContext *context)
: OpRewritePattern<TF::TensorListReserveOp>(context, 1) {}
// Rewrites op of the template type initializing a TensorList with a list of ops
// to generate an equivalent raw tensor. Derived classes are required to
// override GetNumElements method.
template <typename OpT>
struct ConvertTFTensorListInitOp : public OpRewritePattern<OpT> {
explicit ConvertTFTensorListInitOp(MLIRContext *context)
: OpRewritePattern<OpT>(context, 1) {}
// Create and return a 1-d tensor with exactly one element equal to the number
// of list elements to initialize the output tensor list with.
virtual Value *GetNumElements(OpT op, PatternRewriter *rewriter) const = 0;
// Rewrites the original op into `tf.fill`. The result tensor shape is
// [num_element, element_shape]. All the values in the result tensor will be
// initialized to 0.
PatternMatchResult matchAndRewrite(TF::TensorListReserveOp op,
PatternMatchResult matchAndRewrite(OpT op,
PatternRewriter &rewriter) const override {
auto element_shape = op.element_shape();
auto shape_dtype = getElementTypeOrSelf(element_shape->getType());
auto num_elements = op.num_elements();
Type element_dtype = op.element_dtype();
int64_t result_rank = -1; // -1 means unknown result rank.
Type result_type = rewriter.getTensorType(element_dtype);
if (auto element_type = op.element_type().dyn_cast<RankedTensorType>()) {
if (auto element_type =
op.element_type().template dyn_cast<RankedTensorType>()) {
result_rank = element_type.getRank() + 1;
// If element type is ranked, then result type will have unknown leading
// dimension and element shape for the following dimensions.
//
// Note: leading dim is not inferred here even if num_elements input is a
// constant.
// Note: leading dim is not inferred here even when it is a constant.
SmallVector<int64_t, 4> result_shape = {-1};
ArrayRef<int64_t> shape = element_type.getShape();
result_shape.append(shape.begin(), shape.end());
result_type = rewriter.getTensorType(result_shape, element_dtype);
}
// The output shape of the result tensor should be [num_elements +
// element_shape].
auto scalar_zero = CreateI32SplatConst(op, &rewriter, {}, 0);
auto leading_dim = rewriter.create<TF::ExpandDimsOp>(
op.getLoc(), rewriter.getTensorType({1}, shape_dtype), num_elements,
scalar_zero);
// Create a 1-D RankedTensorType for result's shape. Number of elements in
// it is equal to the rank of the result, if known. Otherwise, the number of
// elements are unknown and represented with -1. In both cases, we can
// specify dimension using rank of the result.
Type shape_type = rewriter.getTensorType({result_rank}, shape_dtype);
// Add number of elements as the prefix to the element shape to get shape of
// the output tensor.
auto leading_dim = GetNumElements(op, &rewriter);
auto scalar_zero = CreateI32SplatConst(op, &rewriter, {}, 0);
auto list_shape = rewriter.create<TF::ConcatOp>(
op.getLoc(), shape_type, scalar_zero,
ArrayRef<Value *>({leading_dim, element_shape}),
@ -250,6 +254,89 @@ struct ConvertTFTensorListReserve
auto zero = rewriter.create<ConstantOp>(op.getLoc(), zero_type, zero_attr);
rewriter.replaceOpWithNewOp<TF::FillOp>(op, result_type, list_shape, zero);
return Pattern::matchSuccess();
}
};
struct ConvertTFTensorListReserve
: public ConvertTFTensorListInitOp<TF::TensorListReserveOp> {
explicit ConvertTFTensorListReserve(MLIRContext *context)
: ConvertTFTensorListInitOp(context) {}
Value *GetNumElements(TF::TensorListReserveOp op,
PatternRewriter *rewriter) const override {
auto scalar_zero = CreateI32SplatConst(op, rewriter, {}, 0);
auto shape_dtype = getElementTypeOrSelf(op.element_shape()->getType());
return rewriter->create<TF::ExpandDimsOp>(
op.getLoc(), rewriter->getTensorType({1}, shape_dtype),
op.num_elements(), scalar_zero);
}
};
// TODO(hinsu): Replace with declarative patterns once the RewriterGen infra
// supports patterns involving variadic operand ops.
//
// Note that we ignore the second operand `max_num_elements` as we don't have
// any restrictions on the number of elements we can support. So this may
// have a different behavior compared to TensorFlow in case of errors.
struct ConvertTFEmptyTensorList
: public ConvertTFTensorListInitOp<TF::EmptyTensorListOp> {
explicit ConvertTFEmptyTensorList(MLIRContext *context)
: ConvertTFTensorListInitOp(context) {}
Value *GetNumElements(TF::EmptyTensorListOp op,
PatternRewriter *rewriter) const override {
return CreateI32SplatConst(op, rewriter, {1}, 0);
}
};
struct ConvertTFTensorListPushBack : public RewritePattern {
explicit ConvertTFTensorListPushBack(MLIRContext *context)
: RewritePattern(TF::TensorListPushBackOp::getOperationName(), 1,
context) {}
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
TF::TensorListPushBackOp push_back_op = cast<TF::TensorListPushBackOp>(op);
Value *item = push_back_op.tensor();
Type dtype = getElementTypeOrSelf(*item);
// Returns a new type by prepending the specified dimension to the shape of
// the given type if it is a ranked type.
auto with_leading_dim = [&](int64_t dim, Type type) -> Type {
if (RankedTensorType ty = type.dyn_cast<RankedTensorType>()) {
llvm::SmallVector<int64_t, 4> shape = {dim};
shape.append(ty.getShape().begin(), ty.getShape().end());
return rewriter.getTensorType(shape, dtype);
}
return rewriter.getTensorType(dtype);
};
// Expand the shape of the item so that it will have rank same as the input
// tensor and it is compatible for the Concat Op.
Type expanded_item_type = with_leading_dim(1, item->getType());
auto scalar_zero = CreateI32SplatConst(op, &rewriter, {}, 0);
auto expanded_item = rewriter.create<TF::ExpandDimsOp>(
op->getLoc(), expanded_item_type, item, scalar_zero);
// If the variant type in the output handle has item shape available, use it
// to derive the output shape by setting unknown leading dimension.
// Otherwise, result type will be of unranked type.
Type handle_type = push_back_op.output_handle()->getType();
TF::VariantType handle_dtype =
getElementTypeOrSelf(handle_type).cast<TF::VariantType>();
Type result_type = rewriter.getTensorType(dtype);
if (!handle_dtype.getSubtypes().empty()) {
result_type = with_leading_dim(-1, handle_dtype.getSubtypes()[0]);
}
// Concatenate tensor stored in the input handle with the expanded item to
// get a tensor equivalent to the TensorList generated by this op.
rewriter.replaceOpWithNewOp<TF::ConcatOp>(
op, result_type, scalar_zero,
ArrayRef<Value *>({push_back_op.input_handle(), expanded_item}),
rewriter.getI64IntegerAttr(2));
return matchSuccess();
}
};
@ -354,6 +441,10 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction(
auto c = ConvertTFTensorListReserve(context);
rewriter->setInsertionPoint(op);
c.matchAndRewrite(tf_op, *rewriter);
} else if (auto tf_op = llvm::dyn_cast<TF::EmptyTensorListOp>(op)) {
auto c = ConvertTFEmptyTensorList(context);
rewriter->setInsertionPoint(op);
c.matchAndRewrite(tf_op, *rewriter);
} else if (auto tf_op = llvm::dyn_cast<TF::TensorListGetItemOp>(op)) {
auto c = TFL::ConvertTFTensorListGetItem(context);
rewriter->setInsertionPoint(op);
@ -366,6 +457,10 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction(
auto c = TFL::ConvertTFTensorListStack(context);
rewriter->setInsertionPoint(op);
c.matchAndRewrite(op, *rewriter);
} else if (auto tf_op = llvm::dyn_cast<TF::TensorListPushBackOp>(op)) {
auto c = ConvertTFTensorListPushBack(context);
rewriter->setInsertionPoint(op);
c.matchAndRewrite(op, *rewriter);
} else if (auto tf_op = llvm::dyn_cast<TF::WhileOp>(op)) {
if (op->getAttr("T")) op->removeAttr(Identifier::get("T", context));
UpdateWhileFunctionType(tf_op);

View File

@ -242,6 +242,22 @@ void DivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
RewriteListBuilder<DivWithSqrtDivisor>::build(results, context);
}
//===----------------------------------------------------------------------===//
// EmptyTensorListOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(EmptyTensorListOp op) {
if (!IsOfRankOrUnranked(op.element_shape(), 0) &&
!IsOfRankOrUnranked(op.element_shape(), 1)) {
return op.emitOpError("requires element_shape operand to be 0D/1D tensor");
}
if (!IsOfRankOrUnranked(op.max_num_elements(), 0)) {
return op.emitOpError("requires max_num_elements operand to be 0D tensor");
}
return success();
}
//===----------------------------------------------------------------------===//
// FakeQuantWithMinMaxArgsOp
//===----------------------------------------------------------------------===//

View File

@ -30,6 +30,37 @@ limitations under the License.
include "tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td"
class TF_TensorListInitOp<string mnemonic> : TF_Op<mnemonic, [NoSideEffect]> {
let results = (outs
TF_VariantTensor:$handle
);
TF_DerivedOperandTypeAttr shape_type = TF_DerivedOperandTypeAttr<0>;
let verifier = [{
if (handle_dtype().getSubtypes().size() != 1) {
return emitOpError(
"must have exactly one subtype in the result variant type");
}
return Verify(*this);
}];
DerivedTypeAttr element_dtype = DerivedTypeAttr<
"return getElementTypeOrSelf(element_type());">;
let extraClassDeclaration = [{
// Returns type of the TensorList element produced by this op.
TensorType element_type() { return handle_dtype().getSubtypes()[0]; }
// Returns data type of the result handle. Returned type contains type of
// the TensorList element as a subtype.
VariantType handle_dtype() {
return getElementTypeOrSelf(handle()->getType()).cast<TF::VariantType>();
}
}];
}
// In MLIR, the TensorFlow tensor value is represented as an ElementsAttr, with
// its type encoding the tensor's shape and data type.
def TF_ConstOp : TF_Op<"Const", [NoSideEffect]> {
@ -55,6 +86,24 @@ def TF_ConstOp : TF_Op<"Const", [NoSideEffect]> {
let hasFolder = 1;
}
def TF_EmptyTensorListOp : TF_TensorListInitOp<"EmptyTensorList"> {
let summary = "Creates and returns an empty tensor list.";
let description = [{
All list elements must be tensors of dtype element_dtype and shape compatible
with element_shape.
handle: an empty tensor list.
element_dtype: the type of elements in the list.
element_shape: a shape compatible with that of elements in the list.
}];
let arguments = (ins
TF_I32OrI64Tensor:$element_shape,
I32Tensor:$max_num_elements
);
}
// TODO(fengliuai): The tf.Identity is side-effect free and it doesn't change
// the status of the system during the execution. However it shouldn't be folded
// in general if it used to serve for caching and some other invariant checks,
@ -191,51 +240,6 @@ Inserts a placeholder for a tensor that will be always fed.
TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
}
def TF_TensorListReserveOp : TF_Op<"TensorListReserve", [NoSideEffect]> {
let summary = "List of the given size with empty elements.";
let description = [{
element_shape: the shape of the future elements of the list
num_elements: the number of elements to reserve
handle: the output list
element_dtype: the desired type of elements in the list.
}];
let arguments = (ins
TF_I32OrI64Tensor:$element_shape,
I32Tensor:$num_elements
);
let results = (outs
TF_VariantTensor:$handle
);
TF_DerivedOperandTypeAttr shape_type = TF_DerivedOperandTypeAttr<0>;
let verifier = [{
if (handle_dtype().getSubtypes().size() != 1) {
return emitOpError(
"must have exactly one subtype in the result variant type");
}
return Verify(*this);
}];
DerivedTypeAttr element_dtype = DerivedTypeAttr<
"return getElementTypeOrSelf(element_type());">;
let extraClassDeclaration = [{
// Returns type of the TensorList element produced by this op.
TensorType element_type() { return handle_dtype().getSubtypes()[0]; }
// Returns data type of the result handle. Returned type contains type of
// the TensorList element as a subtype.
VariantType handle_dtype() {
return getElementTypeOrSelf(handle()->getType()).cast<TF::VariantType>();
}
}];
}
def TF_WhileOp : TF_Op<"While", []> {
let summary = [{
output = input; While (Cond(output)) { output = Body(output) }
@ -278,4 +282,20 @@ body: A function that takes a list of tensors and returns another
}];
}
def TF_TensorListReserveOp : TF_TensorListInitOp<"TensorListReserve"> {
let summary = "List of the given size with empty elements.";
let description = [{
element_shape: the shape of the future elements of the list
num_elements: the number of elements to reserve
handle: the output list
element_dtype: the desired type of elements in the list.
}];
let arguments = (ins
TF_I32OrI64Tensor:$element_shape,
I32Tensor:$num_elements
);
}
#endif // TF_OPS