Convert EmptyTensorList and TensorListPushBack ops in tfl_lower_static_tensor_list pass
* Define EmptyTensorList op by sharing the common members with TensorListReserve op because all members except the summary, description and arguments are common. * Introduce a new template RewritePattern to share logic between EmptyTensorList and TensorListReserve op. These two only differs in the number of elements present in the initialized tensor list. * Convert TensorListPushBack by expanding the item to match rank of the input and then concatenating them. PiperOrigin-RevId: 261054481
This commit is contained in:
parent
1e91bd2715
commit
9ac55fbefc
@ -92,13 +92,14 @@ func @tensorlistReserve(%arg0: tensor<3xi32>, %arg1: tensor<i32>, %arg2: tensor<
|
||||
return %1 : tensor<?x?x?xf32>
|
||||
|
||||
// CHECK-LABEL: tensorlistReserve
|
||||
// CHECK: %cst = constant dense<0> : tensor<i32>
|
||||
// CHECK: %0 = "tf.ExpandDims"(%arg1, %cst) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
|
||||
// CHECK: %1 = "tf.Concat"(%cst, %0, %arg0) {N = 2 : i64} : (tensor<i32>, tensor<1xi32>, tensor<3xi32>) -> tensor<4xi32>
|
||||
// CHECK: %cst_0 = constant dense<0.000000e+00> : tensor<f32>
|
||||
// CHECK: %2 = "tf.Fill"(%1, %cst_0) : (tensor<4xi32>, tensor<f32>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK: %3 = "tf.Gather"(%2, %arg2) {validate_indices = true} : (tensor<?x?x?x?xf32>, tensor<i32>) -> tensor<?x?x?xf32>
|
||||
// CHECK: return %3 : tensor<?x?x?xf32>
|
||||
// CHECK-DAG: [[ZERO1:%cst.*]] = constant dense<0> : tensor<i32>
|
||||
// CHECK-DAG: [[ZERO2:%cst.*]] = constant dense<0> : tensor<i32>
|
||||
// CHECK-DAG: [[DIM0:%.*]] = "tf.ExpandDims"(%arg1, [[ZERO1]]) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
|
||||
// CHECK-DAG: [[SHAPE:%.*]] = "tf.Concat"([[ZERO2]], [[DIM0]], %arg0) {N = 2 : i64} : (tensor<i32>, tensor<1xi32>, tensor<3xi32>) -> tensor<4xi32>
|
||||
// CHECK-DAG: [[VALUES:%.*]] = constant dense<0.000000e+00> : tensor<f32>
|
||||
// CHECK: [[LIST:%.*]] = "tf.Fill"([[SHAPE]], [[VALUES]]) : (tensor<4xi32>, tensor<f32>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK: [[RESULT:%.*]] = "tf.Gather"([[LIST]], %arg2) {validate_indices = true} : (tensor<?x?x?x?xf32>, tensor<i32>) -> tensor<?x?x?xf32>
|
||||
// CHECK: return [[RESULT]] : tensor<?x?x?xf32>
|
||||
}
|
||||
|
||||
func @tensorlistReserveUnrankedElements(%arg0: tensor<?xi32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<*xf32> {
|
||||
@ -112,6 +113,36 @@ func @tensorlistReserveUnrankedElements(%arg0: tensor<?xi32>, %arg1: tensor<i32>
|
||||
// CHECK: return [[RESULT2]] : tensor<*xf32>
|
||||
}
|
||||
|
||||
func @EmptyTensorList(%arg0: tensor<3xi32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<?x?x?xf32> {
|
||||
%0 = "tf.EmptyTensorList"(%arg0, %arg1) : (tensor<3xi32>, tensor<i32>) -> tensor<!tf.variant<tensor<?x?x?xf32>>>
|
||||
%1 = "tf.TensorListGetItem"(%0, %arg2, %arg0) : (tensor<!tf.variant<tensor<?x?x?xf32>>>, tensor<i32>, tensor<3xi32>) -> tensor<?x?x?xf32>
|
||||
return %1 : tensor<?x?x?xf32>
|
||||
|
||||
// CHECK-LABEL: EmptyTensorList
|
||||
// CHECK-SAME: ([[ELEM_SHAPE:%.*]]: tensor<3xi32>, [[MAX_ELEMS:%.*]]: tensor<i32>, [[IDX:%.*]]: tensor<i32>)
|
||||
// CHECK-DAG: [[DIM0:%cst.*]] = constant dense<0> : tensor<1xi32>
|
||||
// CHECK-DAG: [[ZERO:%cst.*]] = constant dense<0> : tensor<i32>
|
||||
// CHECK-DAG: [[SHAPE:%.*]] = "tf.Concat"([[ZERO]], [[DIM0]], [[ELEM_SHAPE]]) {N = 2 : i64} : (tensor<i32>, tensor<1xi32>, tensor<3xi32>) -> tensor<4xi32>
|
||||
// CHECK-DAG: [[VALUES:%.*]] = constant dense<0.000000e+00> : tensor<f32>
|
||||
// CHECK: [[LIST:%.*]] = "tf.Fill"([[SHAPE]], [[VALUES]]) : (tensor<4xi32>, tensor<f32>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK: [[RESULT:%.*]] = "tf.Gather"([[LIST]], [[IDX]]) {validate_indices = true} : (tensor<?x?x?x?xf32>, tensor<i32>) -> tensor<?x?x?xf32>
|
||||
// CHECK: return [[RESULT]] : tensor<?x?x?xf32>
|
||||
}
|
||||
|
||||
func @tensorlistPushBack(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor<10xf32>) -> tensor<?x10xf32> {
|
||||
%0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<3x10xf32>, tensor<1xi32>) -> tensor<!tf.variant<tensor<10xf32>>>
|
||||
%1 = "tf.TensorListPushBack"(%0, %arg2) : (tensor<!tf.variant<tensor<10xf32>>>, tensor<10xf32>) -> tensor<!tf.variant<tensor<10xf32>>>
|
||||
%2 = "tf.TensorListStack"(%1, %arg1) : (tensor<!tf.variant<tensor<10xf32>>>, tensor<1xi32>) -> tensor<?x10xf32>
|
||||
return %2 : tensor<?x10xf32>
|
||||
|
||||
// CHECK-LABEL: tensorlistPushBack
|
||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<3x10xf32>, [[ELEM_SHAPE:%.*]]: tensor<1xi32>, [[ITEM:%.*]]: tensor<10xf32>)
|
||||
// CHECK: [[ZERO:%.*]] = constant dense<0> : tensor<i32>
|
||||
// CHECK: [[EXP_ITEM:%.*]] = "tf.ExpandDims"([[ITEM]], [[ZERO]]) {{.*}} -> tensor<1x10xf32>
|
||||
// CHECK: [[RESULT:%.*]] = "tf.Concat"(%cst, [[INPUT]], [[EXP_ITEM]]) {N = 2 : i64} : {{.*}} -> tensor<?x10xf32>
|
||||
// CHECK: return [[RESULT]] : tensor<?x10xf32>
|
||||
}
|
||||
|
||||
func @tensorlistWhileLoop(%arg0: tensor<2x3xf32>) -> tensor<*xf32> {
|
||||
%cst = constant dense<3> : tensor<1xi32>
|
||||
%cst_0 = constant dense<0> : tensor<i32>
|
||||
|
@ -196,48 +196,52 @@ struct ConvertTFTensorListSetItem
|
||||
}
|
||||
};
|
||||
|
||||
struct ConvertTFTensorListReserve
|
||||
: public OpRewritePattern<TF::TensorListReserveOp> {
|
||||
explicit ConvertTFTensorListReserve(MLIRContext *context)
|
||||
: OpRewritePattern<TF::TensorListReserveOp>(context, 1) {}
|
||||
// Rewrites op of the template type initializing a TensorList with a list of ops
|
||||
// to generate an equivalent raw tensor. Derived classes are required to
|
||||
// override GetNumElements method.
|
||||
template <typename OpT>
|
||||
struct ConvertTFTensorListInitOp : public OpRewritePattern<OpT> {
|
||||
explicit ConvertTFTensorListInitOp(MLIRContext *context)
|
||||
: OpRewritePattern<OpT>(context, 1) {}
|
||||
|
||||
// Create and return a 1-d tensor with exactly one element equal to the number
|
||||
// of list elements to initialize the output tensor list with.
|
||||
virtual Value *GetNumElements(OpT op, PatternRewriter *rewriter) const = 0;
|
||||
|
||||
// Rewrites the original op into `tf.fill`. The result tensor shape is
|
||||
// [num_element, element_shape]. All the values in the result tensor will be
|
||||
// initialized to 0.
|
||||
PatternMatchResult matchAndRewrite(TF::TensorListReserveOp op,
|
||||
PatternMatchResult matchAndRewrite(OpT op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto element_shape = op.element_shape();
|
||||
auto shape_dtype = getElementTypeOrSelf(element_shape->getType());
|
||||
auto num_elements = op.num_elements();
|
||||
Type element_dtype = op.element_dtype();
|
||||
|
||||
int64_t result_rank = -1; // -1 means unknown result rank.
|
||||
Type result_type = rewriter.getTensorType(element_dtype);
|
||||
if (auto element_type = op.element_type().dyn_cast<RankedTensorType>()) {
|
||||
if (auto element_type =
|
||||
op.element_type().template dyn_cast<RankedTensorType>()) {
|
||||
result_rank = element_type.getRank() + 1;
|
||||
// If element type is ranked, then result type will have unknown leading
|
||||
// dimension and element shape for the following dimensions.
|
||||
//
|
||||
// Note: leading dim is not inferred here even if num_elements input is a
|
||||
// constant.
|
||||
// Note: leading dim is not inferred here even when it is a constant.
|
||||
SmallVector<int64_t, 4> result_shape = {-1};
|
||||
ArrayRef<int64_t> shape = element_type.getShape();
|
||||
result_shape.append(shape.begin(), shape.end());
|
||||
result_type = rewriter.getTensorType(result_shape, element_dtype);
|
||||
}
|
||||
|
||||
// The output shape of the result tensor should be [num_elements +
|
||||
// element_shape].
|
||||
auto scalar_zero = CreateI32SplatConst(op, &rewriter, {}, 0);
|
||||
auto leading_dim = rewriter.create<TF::ExpandDimsOp>(
|
||||
op.getLoc(), rewriter.getTensorType({1}, shape_dtype), num_elements,
|
||||
scalar_zero);
|
||||
|
||||
// Create a 1-D RankedTensorType for result's shape. Number of elements in
|
||||
// it is equal to the rank of the result, if known. Otherwise, the number of
|
||||
// elements are unknown and represented with -1. In both cases, we can
|
||||
// specify dimension using rank of the result.
|
||||
Type shape_type = rewriter.getTensorType({result_rank}, shape_dtype);
|
||||
|
||||
// Add number of elements as the prefix to the element shape to get shape of
|
||||
// the output tensor.
|
||||
auto leading_dim = GetNumElements(op, &rewriter);
|
||||
auto scalar_zero = CreateI32SplatConst(op, &rewriter, {}, 0);
|
||||
auto list_shape = rewriter.create<TF::ConcatOp>(
|
||||
op.getLoc(), shape_type, scalar_zero,
|
||||
ArrayRef<Value *>({leading_dim, element_shape}),
|
||||
@ -250,6 +254,89 @@ struct ConvertTFTensorListReserve
|
||||
auto zero = rewriter.create<ConstantOp>(op.getLoc(), zero_type, zero_attr);
|
||||
|
||||
rewriter.replaceOpWithNewOp<TF::FillOp>(op, result_type, list_shape, zero);
|
||||
return Pattern::matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
struct ConvertTFTensorListReserve
|
||||
: public ConvertTFTensorListInitOp<TF::TensorListReserveOp> {
|
||||
explicit ConvertTFTensorListReserve(MLIRContext *context)
|
||||
: ConvertTFTensorListInitOp(context) {}
|
||||
|
||||
Value *GetNumElements(TF::TensorListReserveOp op,
|
||||
PatternRewriter *rewriter) const override {
|
||||
auto scalar_zero = CreateI32SplatConst(op, rewriter, {}, 0);
|
||||
auto shape_dtype = getElementTypeOrSelf(op.element_shape()->getType());
|
||||
return rewriter->create<TF::ExpandDimsOp>(
|
||||
op.getLoc(), rewriter->getTensorType({1}, shape_dtype),
|
||||
op.num_elements(), scalar_zero);
|
||||
}
|
||||
};
|
||||
|
||||
// TODO(hinsu): Replace with declarative patterns once the RewriterGen infra
|
||||
// supports patterns involving variadic operand ops.
|
||||
//
|
||||
// Note that we ignore the second operand `max_num_elements` as we don't have
|
||||
// any restrictions on the number of elements we can support. So this may
|
||||
// have a different behavior compared to TensorFlow in case of errors.
|
||||
struct ConvertTFEmptyTensorList
|
||||
: public ConvertTFTensorListInitOp<TF::EmptyTensorListOp> {
|
||||
explicit ConvertTFEmptyTensorList(MLIRContext *context)
|
||||
: ConvertTFTensorListInitOp(context) {}
|
||||
|
||||
Value *GetNumElements(TF::EmptyTensorListOp op,
|
||||
PatternRewriter *rewriter) const override {
|
||||
return CreateI32SplatConst(op, rewriter, {1}, 0);
|
||||
}
|
||||
};
|
||||
|
||||
struct ConvertTFTensorListPushBack : public RewritePattern {
|
||||
explicit ConvertTFTensorListPushBack(MLIRContext *context)
|
||||
: RewritePattern(TF::TensorListPushBackOp::getOperationName(), 1,
|
||||
context) {}
|
||||
|
||||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
TF::TensorListPushBackOp push_back_op = cast<TF::TensorListPushBackOp>(op);
|
||||
Value *item = push_back_op.tensor();
|
||||
Type dtype = getElementTypeOrSelf(*item);
|
||||
|
||||
// Returns a new type by prepending the specified dimension to the shape of
|
||||
// the given type if it is a ranked type.
|
||||
auto with_leading_dim = [&](int64_t dim, Type type) -> Type {
|
||||
if (RankedTensorType ty = type.dyn_cast<RankedTensorType>()) {
|
||||
llvm::SmallVector<int64_t, 4> shape = {dim};
|
||||
shape.append(ty.getShape().begin(), ty.getShape().end());
|
||||
return rewriter.getTensorType(shape, dtype);
|
||||
}
|
||||
|
||||
return rewriter.getTensorType(dtype);
|
||||
};
|
||||
|
||||
// Expand the shape of the item so that it will have rank same as the input
|
||||
// tensor and it is compatible for the Concat Op.
|
||||
Type expanded_item_type = with_leading_dim(1, item->getType());
|
||||
auto scalar_zero = CreateI32SplatConst(op, &rewriter, {}, 0);
|
||||
auto expanded_item = rewriter.create<TF::ExpandDimsOp>(
|
||||
op->getLoc(), expanded_item_type, item, scalar_zero);
|
||||
|
||||
// If the variant type in the output handle has item shape available, use it
|
||||
// to derive the output shape by setting unknown leading dimension.
|
||||
// Otherwise, result type will be of unranked type.
|
||||
Type handle_type = push_back_op.output_handle()->getType();
|
||||
TF::VariantType handle_dtype =
|
||||
getElementTypeOrSelf(handle_type).cast<TF::VariantType>();
|
||||
Type result_type = rewriter.getTensorType(dtype);
|
||||
if (!handle_dtype.getSubtypes().empty()) {
|
||||
result_type = with_leading_dim(-1, handle_dtype.getSubtypes()[0]);
|
||||
}
|
||||
|
||||
// Concatenate tensor stored in the input handle with the expanded item to
|
||||
// get a tensor equivalent to the TensorList generated by this op.
|
||||
rewriter.replaceOpWithNewOp<TF::ConcatOp>(
|
||||
op, result_type, scalar_zero,
|
||||
ArrayRef<Value *>({push_back_op.input_handle(), expanded_item}),
|
||||
rewriter.getI64IntegerAttr(2));
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
@ -354,6 +441,10 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction(
|
||||
auto c = ConvertTFTensorListReserve(context);
|
||||
rewriter->setInsertionPoint(op);
|
||||
c.matchAndRewrite(tf_op, *rewriter);
|
||||
} else if (auto tf_op = llvm::dyn_cast<TF::EmptyTensorListOp>(op)) {
|
||||
auto c = ConvertTFEmptyTensorList(context);
|
||||
rewriter->setInsertionPoint(op);
|
||||
c.matchAndRewrite(tf_op, *rewriter);
|
||||
} else if (auto tf_op = llvm::dyn_cast<TF::TensorListGetItemOp>(op)) {
|
||||
auto c = TFL::ConvertTFTensorListGetItem(context);
|
||||
rewriter->setInsertionPoint(op);
|
||||
@ -366,6 +457,10 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction(
|
||||
auto c = TFL::ConvertTFTensorListStack(context);
|
||||
rewriter->setInsertionPoint(op);
|
||||
c.matchAndRewrite(op, *rewriter);
|
||||
} else if (auto tf_op = llvm::dyn_cast<TF::TensorListPushBackOp>(op)) {
|
||||
auto c = ConvertTFTensorListPushBack(context);
|
||||
rewriter->setInsertionPoint(op);
|
||||
c.matchAndRewrite(op, *rewriter);
|
||||
} else if (auto tf_op = llvm::dyn_cast<TF::WhileOp>(op)) {
|
||||
if (op->getAttr("T")) op->removeAttr(Identifier::get("T", context));
|
||||
UpdateWhileFunctionType(tf_op);
|
||||
|
@ -242,6 +242,22 @@ void DivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
RewriteListBuilder<DivWithSqrtDivisor>::build(results, context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// EmptyTensorListOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult Verify(EmptyTensorListOp op) {
|
||||
if (!IsOfRankOrUnranked(op.element_shape(), 0) &&
|
||||
!IsOfRankOrUnranked(op.element_shape(), 1)) {
|
||||
return op.emitOpError("requires element_shape operand to be 0D/1D tensor");
|
||||
}
|
||||
|
||||
if (!IsOfRankOrUnranked(op.max_num_elements(), 0)) {
|
||||
return op.emitOpError("requires max_num_elements operand to be 0D tensor");
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// FakeQuantWithMinMaxArgsOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -30,6 +30,37 @@ limitations under the License.
|
||||
|
||||
include "tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td"
|
||||
|
||||
class TF_TensorListInitOp<string mnemonic> : TF_Op<mnemonic, [NoSideEffect]> {
|
||||
let results = (outs
|
||||
TF_VariantTensor:$handle
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr shape_type = TF_DerivedOperandTypeAttr<0>;
|
||||
|
||||
let verifier = [{
|
||||
if (handle_dtype().getSubtypes().size() != 1) {
|
||||
return emitOpError(
|
||||
"must have exactly one subtype in the result variant type");
|
||||
}
|
||||
|
||||
return Verify(*this);
|
||||
}];
|
||||
|
||||
DerivedTypeAttr element_dtype = DerivedTypeAttr<
|
||||
"return getElementTypeOrSelf(element_type());">;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// Returns type of the TensorList element produced by this op.
|
||||
TensorType element_type() { return handle_dtype().getSubtypes()[0]; }
|
||||
|
||||
// Returns data type of the result handle. Returned type contains type of
|
||||
// the TensorList element as a subtype.
|
||||
VariantType handle_dtype() {
|
||||
return getElementTypeOrSelf(handle()->getType()).cast<TF::VariantType>();
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
// In MLIR, the TensorFlow tensor value is represented as an ElementsAttr, with
|
||||
// its type encoding the tensor's shape and data type.
|
||||
def TF_ConstOp : TF_Op<"Const", [NoSideEffect]> {
|
||||
@ -55,6 +86,24 @@ def TF_ConstOp : TF_Op<"Const", [NoSideEffect]> {
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TF_EmptyTensorListOp : TF_TensorListInitOp<"EmptyTensorList"> {
|
||||
let summary = "Creates and returns an empty tensor list.";
|
||||
|
||||
let description = [{
|
||||
All list elements must be tensors of dtype element_dtype and shape compatible
|
||||
with element_shape.
|
||||
|
||||
handle: an empty tensor list.
|
||||
element_dtype: the type of elements in the list.
|
||||
element_shape: a shape compatible with that of elements in the list.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_I32OrI64Tensor:$element_shape,
|
||||
I32Tensor:$max_num_elements
|
||||
);
|
||||
}
|
||||
|
||||
// TODO(fengliuai): The tf.Identity is side-effect free and it doesn't change
|
||||
// the status of the system during the execution. However it shouldn't be folded
|
||||
// in general if it used to serve for caching and some other invariant checks,
|
||||
@ -191,51 +240,6 @@ Inserts a placeholder for a tensor that will be always fed.
|
||||
TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_TensorListReserveOp : TF_Op<"TensorListReserve", [NoSideEffect]> {
|
||||
let summary = "List of the given size with empty elements.";
|
||||
|
||||
let description = [{
|
||||
element_shape: the shape of the future elements of the list
|
||||
num_elements: the number of elements to reserve
|
||||
handle: the output list
|
||||
element_dtype: the desired type of elements in the list.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_I32OrI64Tensor:$element_shape,
|
||||
I32Tensor:$num_elements
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_VariantTensor:$handle
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr shape_type = TF_DerivedOperandTypeAttr<0>;
|
||||
|
||||
let verifier = [{
|
||||
if (handle_dtype().getSubtypes().size() != 1) {
|
||||
return emitOpError(
|
||||
"must have exactly one subtype in the result variant type");
|
||||
}
|
||||
|
||||
return Verify(*this);
|
||||
}];
|
||||
|
||||
DerivedTypeAttr element_dtype = DerivedTypeAttr<
|
||||
"return getElementTypeOrSelf(element_type());">;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// Returns type of the TensorList element produced by this op.
|
||||
TensorType element_type() { return handle_dtype().getSubtypes()[0]; }
|
||||
|
||||
// Returns data type of the result handle. Returned type contains type of
|
||||
// the TensorList element as a subtype.
|
||||
VariantType handle_dtype() {
|
||||
return getElementTypeOrSelf(handle()->getType()).cast<TF::VariantType>();
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_WhileOp : TF_Op<"While", []> {
|
||||
let summary = [{
|
||||
output = input; While (Cond(output)) { output = Body(output) }
|
||||
@ -278,4 +282,20 @@ body: A function that takes a list of tensors and returns another
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_TensorListReserveOp : TF_TensorListInitOp<"TensorListReserve"> {
|
||||
let summary = "List of the given size with empty elements.";
|
||||
|
||||
let description = [{
|
||||
element_shape: the shape of the future elements of the list
|
||||
num_elements: the number of elements to reserve
|
||||
handle: the output list
|
||||
element_dtype: the desired type of elements in the list.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_I32OrI64Tensor:$element_shape,
|
||||
I32Tensor:$num_elements
|
||||
);
|
||||
}
|
||||
|
||||
#endif // TF_OPS
|
||||
|
Loading…
Reference in New Issue
Block a user