Convert EmptyTensorList and TensorListPushBack ops in tfl_lower_static_tensor_list pass
* Define EmptyTensorList op by sharing the common members with TensorListReserve op because all members except the summary, description and arguments are common. * Introduce a new template RewritePattern to share logic between EmptyTensorList and TensorListReserve op. These two only differs in the number of elements present in the initialized tensor list. * Convert TensorListPushBack by expanding the item to match rank of the input and then concatenating them. PiperOrigin-RevId: 261054481
This commit is contained in:
		
							parent
							
								
									1e91bd2715
								
							
						
					
					
						commit
						9ac55fbefc
					
				| @ -92,13 +92,14 @@ func @tensorlistReserve(%arg0: tensor<3xi32>, %arg1: tensor<i32>, %arg2: tensor< | ||||
|   return %1 : tensor<?x?x?xf32> | ||||
| 
 | ||||
| // CHECK-LABEL: tensorlistReserve | ||||
| // CHECK:  %cst = constant dense<0> : tensor<i32> | ||||
| // CHECK:  %0 = "tf.ExpandDims"(%arg1, %cst) : (tensor<i32>, tensor<i32>) -> tensor<1xi32> | ||||
| // CHECK:  %1 = "tf.Concat"(%cst, %0, %arg0) {N = 2 : i64} : (tensor<i32>, tensor<1xi32>, tensor<3xi32>) -> tensor<4xi32> | ||||
| // CHECK:  %cst_0 = constant dense<0.000000e+00> : tensor<f32> | ||||
| // CHECK:  %2 = "tf.Fill"(%1, %cst_0) : (tensor<4xi32>, tensor<f32>) -> tensor<?x?x?x?xf32> | ||||
| // CHECK:  %3 = "tf.Gather"(%2, %arg2) {validate_indices = true} : (tensor<?x?x?x?xf32>, tensor<i32>) -> tensor<?x?x?xf32> | ||||
| // CHECK:  return %3 : tensor<?x?x?xf32> | ||||
| // CHECK-DAG:  [[ZERO1:%cst.*]] = constant dense<0> : tensor<i32> | ||||
| // CHECK-DAG:  [[ZERO2:%cst.*]] = constant dense<0> : tensor<i32> | ||||
| // CHECK-DAG:  [[DIM0:%.*]] = "tf.ExpandDims"(%arg1, [[ZERO1]]) : (tensor<i32>, tensor<i32>) -> tensor<1xi32> | ||||
| // CHECK-DAG:  [[SHAPE:%.*]] = "tf.Concat"([[ZERO2]], [[DIM0]], %arg0) {N = 2 : i64} : (tensor<i32>, tensor<1xi32>, tensor<3xi32>) -> tensor<4xi32> | ||||
| // CHECK-DAG:  [[VALUES:%.*]] = constant dense<0.000000e+00> : tensor<f32> | ||||
| // CHECK:      [[LIST:%.*]] = "tf.Fill"([[SHAPE]], [[VALUES]]) : (tensor<4xi32>, tensor<f32>) -> tensor<?x?x?x?xf32> | ||||
| // CHECK:      [[RESULT:%.*]] = "tf.Gather"([[LIST]], %arg2) {validate_indices = true} : (tensor<?x?x?x?xf32>, tensor<i32>) -> tensor<?x?x?xf32> | ||||
| // CHECK:      return [[RESULT]] : tensor<?x?x?xf32> | ||||
| } | ||||
| 
 | ||||
| func @tensorlistReserveUnrankedElements(%arg0: tensor<?xi32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<*xf32> { | ||||
| @ -112,6 +113,36 @@ func @tensorlistReserveUnrankedElements(%arg0: tensor<?xi32>, %arg1: tensor<i32> | ||||
| // CHECK:  return [[RESULT2]] : tensor<*xf32> | ||||
| } | ||||
| 
 | ||||
| func @EmptyTensorList(%arg0: tensor<3xi32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<?x?x?xf32> { | ||||
|   %0 = "tf.EmptyTensorList"(%arg0, %arg1) : (tensor<3xi32>, tensor<i32>) -> tensor<!tf.variant<tensor<?x?x?xf32>>> | ||||
|   %1 = "tf.TensorListGetItem"(%0, %arg2, %arg0) : (tensor<!tf.variant<tensor<?x?x?xf32>>>, tensor<i32>, tensor<3xi32>) -> tensor<?x?x?xf32> | ||||
|   return %1 : tensor<?x?x?xf32> | ||||
| 
 | ||||
| // CHECK-LABEL: EmptyTensorList | ||||
| // CHECK-SAME:  ([[ELEM_SHAPE:%.*]]: tensor<3xi32>, [[MAX_ELEMS:%.*]]: tensor<i32>, [[IDX:%.*]]: tensor<i32>) | ||||
| // CHECK-DAG:  [[DIM0:%cst.*]] = constant dense<0> : tensor<1xi32> | ||||
| // CHECK-DAG:  [[ZERO:%cst.*]] = constant dense<0> : tensor<i32> | ||||
| // CHECK-DAG:  [[SHAPE:%.*]] = "tf.Concat"([[ZERO]], [[DIM0]], [[ELEM_SHAPE]]) {N = 2 : i64} : (tensor<i32>, tensor<1xi32>, tensor<3xi32>) -> tensor<4xi32> | ||||
| // CHECK-DAG:  [[VALUES:%.*]] = constant dense<0.000000e+00> : tensor<f32> | ||||
| // CHECK:      [[LIST:%.*]] = "tf.Fill"([[SHAPE]], [[VALUES]]) : (tensor<4xi32>, tensor<f32>) -> tensor<?x?x?x?xf32> | ||||
| // CHECK:      [[RESULT:%.*]] = "tf.Gather"([[LIST]], [[IDX]]) {validate_indices = true} : (tensor<?x?x?x?xf32>, tensor<i32>) -> tensor<?x?x?xf32> | ||||
| // CHECK:      return [[RESULT]] : tensor<?x?x?xf32> | ||||
| } | ||||
| 
 | ||||
| func @tensorlistPushBack(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor<10xf32>) -> tensor<?x10xf32> { | ||||
|   %0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<3x10xf32>, tensor<1xi32>) -> tensor<!tf.variant<tensor<10xf32>>> | ||||
|   %1 = "tf.TensorListPushBack"(%0, %arg2) : (tensor<!tf.variant<tensor<10xf32>>>, tensor<10xf32>) -> tensor<!tf.variant<tensor<10xf32>>> | ||||
|   %2 = "tf.TensorListStack"(%1, %arg1) : (tensor<!tf.variant<tensor<10xf32>>>, tensor<1xi32>) -> tensor<?x10xf32> | ||||
|   return %2 : tensor<?x10xf32> | ||||
| 
 | ||||
| // CHECK-LABEL: tensorlistPushBack | ||||
| // CHECK-SAME:  ([[INPUT:%.*]]: tensor<3x10xf32>, [[ELEM_SHAPE:%.*]]: tensor<1xi32>, [[ITEM:%.*]]: tensor<10xf32>) | ||||
| // CHECK:   [[ZERO:%.*]] = constant dense<0> : tensor<i32> | ||||
| // CHECK:   [[EXP_ITEM:%.*]] = "tf.ExpandDims"([[ITEM]], [[ZERO]]) {{.*}} -> tensor<1x10xf32> | ||||
| // CHECK:   [[RESULT:%.*]] = "tf.Concat"(%cst, [[INPUT]], [[EXP_ITEM]]) {N = 2 : i64} : {{.*}} -> tensor<?x10xf32> | ||||
| // CHECK:   return [[RESULT]] : tensor<?x10xf32> | ||||
| } | ||||
| 
 | ||||
| func @tensorlistWhileLoop(%arg0: tensor<2x3xf32>) -> tensor<*xf32> { | ||||
|   %cst = constant dense<3> : tensor<1xi32> | ||||
|   %cst_0 = constant dense<0> : tensor<i32> | ||||
|  | ||||
| @ -196,48 +196,52 @@ struct ConvertTFTensorListSetItem | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| struct ConvertTFTensorListReserve | ||||
|     : public OpRewritePattern<TF::TensorListReserveOp> { | ||||
|   explicit ConvertTFTensorListReserve(MLIRContext *context) | ||||
|       : OpRewritePattern<TF::TensorListReserveOp>(context, 1) {} | ||||
| // Rewrites op of the template type initializing a TensorList with a list of ops
 | ||||
| // to generate an equivalent raw tensor. Derived classes are required to
 | ||||
| // override GetNumElements method.
 | ||||
| template <typename OpT> | ||||
| struct ConvertTFTensorListInitOp : public OpRewritePattern<OpT> { | ||||
|   explicit ConvertTFTensorListInitOp(MLIRContext *context) | ||||
|       : OpRewritePattern<OpT>(context, 1) {} | ||||
| 
 | ||||
|   // Create and return a 1-d tensor with exactly one element equal to the number
 | ||||
|   // of list elements to initialize the output tensor list with.
 | ||||
|   virtual Value *GetNumElements(OpT op, PatternRewriter *rewriter) const = 0; | ||||
| 
 | ||||
|   // Rewrites the original op into `tf.fill`. The result tensor shape is
 | ||||
|   // [num_element, element_shape]. All the values in the result tensor will be
 | ||||
|   // initialized to 0.
 | ||||
|   PatternMatchResult matchAndRewrite(TF::TensorListReserveOp op, | ||||
|   PatternMatchResult matchAndRewrite(OpT op, | ||||
|                                      PatternRewriter &rewriter) const override { | ||||
|     auto element_shape = op.element_shape(); | ||||
|     auto shape_dtype = getElementTypeOrSelf(element_shape->getType()); | ||||
|     auto num_elements = op.num_elements(); | ||||
|     Type element_dtype = op.element_dtype(); | ||||
| 
 | ||||
|     int64_t result_rank = -1;  // -1 means unknown result rank.
 | ||||
|     Type result_type = rewriter.getTensorType(element_dtype); | ||||
|     if (auto element_type = op.element_type().dyn_cast<RankedTensorType>()) { | ||||
|     if (auto element_type = | ||||
|             op.element_type().template dyn_cast<RankedTensorType>()) { | ||||
|       result_rank = element_type.getRank() + 1; | ||||
|       // If element type is ranked, then result type will have unknown leading
 | ||||
|       // dimension and element shape for the following dimensions.
 | ||||
|       //
 | ||||
|       // Note: leading dim is not inferred here even if num_elements input is a
 | ||||
|       // constant.
 | ||||
|       // Note: leading dim is not inferred here even when it is a constant.
 | ||||
|       SmallVector<int64_t, 4> result_shape = {-1}; | ||||
|       ArrayRef<int64_t> shape = element_type.getShape(); | ||||
|       result_shape.append(shape.begin(), shape.end()); | ||||
|       result_type = rewriter.getTensorType(result_shape, element_dtype); | ||||
|     } | ||||
| 
 | ||||
|     // The output shape of the result tensor should be [num_elements +
 | ||||
|     // element_shape].
 | ||||
|     auto scalar_zero = CreateI32SplatConst(op, &rewriter, {}, 0); | ||||
|     auto leading_dim = rewriter.create<TF::ExpandDimsOp>( | ||||
|         op.getLoc(), rewriter.getTensorType({1}, shape_dtype), num_elements, | ||||
|         scalar_zero); | ||||
| 
 | ||||
|     // Create a 1-D RankedTensorType for result's shape. Number of elements in
 | ||||
|     // it is equal to the rank of the result, if known. Otherwise, the number of
 | ||||
|     // elements are unknown and represented with -1. In both cases, we can
 | ||||
|     // specify dimension using rank of the result.
 | ||||
|     Type shape_type = rewriter.getTensorType({result_rank}, shape_dtype); | ||||
| 
 | ||||
|     // Add number of elements as the prefix to the element shape to get shape of
 | ||||
|     // the output tensor.
 | ||||
|     auto leading_dim = GetNumElements(op, &rewriter); | ||||
|     auto scalar_zero = CreateI32SplatConst(op, &rewriter, {}, 0); | ||||
|     auto list_shape = rewriter.create<TF::ConcatOp>( | ||||
|         op.getLoc(), shape_type, scalar_zero, | ||||
|         ArrayRef<Value *>({leading_dim, element_shape}), | ||||
| @ -250,6 +254,89 @@ struct ConvertTFTensorListReserve | ||||
|     auto zero = rewriter.create<ConstantOp>(op.getLoc(), zero_type, zero_attr); | ||||
| 
 | ||||
|     rewriter.replaceOpWithNewOp<TF::FillOp>(op, result_type, list_shape, zero); | ||||
|     return Pattern::matchSuccess(); | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| struct ConvertTFTensorListReserve | ||||
|     : public ConvertTFTensorListInitOp<TF::TensorListReserveOp> { | ||||
|   explicit ConvertTFTensorListReserve(MLIRContext *context) | ||||
|       : ConvertTFTensorListInitOp(context) {} | ||||
| 
 | ||||
|   Value *GetNumElements(TF::TensorListReserveOp op, | ||||
|                         PatternRewriter *rewriter) const override { | ||||
|     auto scalar_zero = CreateI32SplatConst(op, rewriter, {}, 0); | ||||
|     auto shape_dtype = getElementTypeOrSelf(op.element_shape()->getType()); | ||||
|     return rewriter->create<TF::ExpandDimsOp>( | ||||
|         op.getLoc(), rewriter->getTensorType({1}, shape_dtype), | ||||
|         op.num_elements(), scalar_zero); | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| // TODO(hinsu): Replace with declarative patterns once the RewriterGen infra
 | ||||
| // supports patterns involving variadic operand ops.
 | ||||
| //
 | ||||
| // Note that we ignore the second operand `max_num_elements` as we don't have
 | ||||
| // any restrictions on the number of elements we can support. So this may
 | ||||
| // have a different behavior compared to TensorFlow in case of errors.
 | ||||
| struct ConvertTFEmptyTensorList | ||||
|     : public ConvertTFTensorListInitOp<TF::EmptyTensorListOp> { | ||||
|   explicit ConvertTFEmptyTensorList(MLIRContext *context) | ||||
|       : ConvertTFTensorListInitOp(context) {} | ||||
| 
 | ||||
|   Value *GetNumElements(TF::EmptyTensorListOp op, | ||||
|                         PatternRewriter *rewriter) const override { | ||||
|     return CreateI32SplatConst(op, rewriter, {1}, 0); | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| struct ConvertTFTensorListPushBack : public RewritePattern { | ||||
|   explicit ConvertTFTensorListPushBack(MLIRContext *context) | ||||
|       : RewritePattern(TF::TensorListPushBackOp::getOperationName(), 1, | ||||
|                        context) {} | ||||
| 
 | ||||
|   PatternMatchResult matchAndRewrite(Operation *op, | ||||
|                                      PatternRewriter &rewriter) const override { | ||||
|     TF::TensorListPushBackOp push_back_op = cast<TF::TensorListPushBackOp>(op); | ||||
|     Value *item = push_back_op.tensor(); | ||||
|     Type dtype = getElementTypeOrSelf(*item); | ||||
| 
 | ||||
|     // Returns a new type by prepending the specified dimension to the shape of
 | ||||
|     // the given type if it is a ranked type.
 | ||||
|     auto with_leading_dim = [&](int64_t dim, Type type) -> Type { | ||||
|       if (RankedTensorType ty = type.dyn_cast<RankedTensorType>()) { | ||||
|         llvm::SmallVector<int64_t, 4> shape = {dim}; | ||||
|         shape.append(ty.getShape().begin(), ty.getShape().end()); | ||||
|         return rewriter.getTensorType(shape, dtype); | ||||
|       } | ||||
| 
 | ||||
|       return rewriter.getTensorType(dtype); | ||||
|     }; | ||||
| 
 | ||||
|     // Expand the shape of the item so that it will have rank same as the input
 | ||||
|     // tensor and it is compatible for the Concat Op.
 | ||||
|     Type expanded_item_type = with_leading_dim(1, item->getType()); | ||||
|     auto scalar_zero = CreateI32SplatConst(op, &rewriter, {}, 0); | ||||
|     auto expanded_item = rewriter.create<TF::ExpandDimsOp>( | ||||
|         op->getLoc(), expanded_item_type, item, scalar_zero); | ||||
| 
 | ||||
|     // If the variant type in the output handle has item shape available, use it
 | ||||
|     // to derive the output shape by setting unknown leading dimension.
 | ||||
|     // Otherwise, result type will be of unranked type.
 | ||||
|     Type handle_type = push_back_op.output_handle()->getType(); | ||||
|     TF::VariantType handle_dtype = | ||||
|         getElementTypeOrSelf(handle_type).cast<TF::VariantType>(); | ||||
|     Type result_type = rewriter.getTensorType(dtype); | ||||
|     if (!handle_dtype.getSubtypes().empty()) { | ||||
|       result_type = with_leading_dim(-1, handle_dtype.getSubtypes()[0]); | ||||
|     } | ||||
| 
 | ||||
|     // Concatenate tensor stored in the input handle with the expanded item to
 | ||||
|     // get a tensor equivalent to the TensorList generated by this op.
 | ||||
|     rewriter.replaceOpWithNewOp<TF::ConcatOp>( | ||||
|         op, result_type, scalar_zero, | ||||
|         ArrayRef<Value *>({push_back_op.input_handle(), expanded_item}), | ||||
|         rewriter.getI64IntegerAttr(2)); | ||||
|     return matchSuccess(); | ||||
|   } | ||||
| }; | ||||
| @ -354,6 +441,10 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction( | ||||
|         auto c = ConvertTFTensorListReserve(context); | ||||
|         rewriter->setInsertionPoint(op); | ||||
|         c.matchAndRewrite(tf_op, *rewriter); | ||||
|       } else if (auto tf_op = llvm::dyn_cast<TF::EmptyTensorListOp>(op)) { | ||||
|         auto c = ConvertTFEmptyTensorList(context); | ||||
|         rewriter->setInsertionPoint(op); | ||||
|         c.matchAndRewrite(tf_op, *rewriter); | ||||
|       } else if (auto tf_op = llvm::dyn_cast<TF::TensorListGetItemOp>(op)) { | ||||
|         auto c = TFL::ConvertTFTensorListGetItem(context); | ||||
|         rewriter->setInsertionPoint(op); | ||||
| @ -366,6 +457,10 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction( | ||||
|         auto c = TFL::ConvertTFTensorListStack(context); | ||||
|         rewriter->setInsertionPoint(op); | ||||
|         c.matchAndRewrite(op, *rewriter); | ||||
|       } else if (auto tf_op = llvm::dyn_cast<TF::TensorListPushBackOp>(op)) { | ||||
|         auto c = ConvertTFTensorListPushBack(context); | ||||
|         rewriter->setInsertionPoint(op); | ||||
|         c.matchAndRewrite(op, *rewriter); | ||||
|       } else if (auto tf_op = llvm::dyn_cast<TF::WhileOp>(op)) { | ||||
|         if (op->getAttr("T")) op->removeAttr(Identifier::get("T", context)); | ||||
|         UpdateWhileFunctionType(tf_op); | ||||
|  | ||||
| @ -242,6 +242,22 @@ void DivOp::getCanonicalizationPatterns(OwningRewritePatternList &results, | ||||
|   RewriteListBuilder<DivWithSqrtDivisor>::build(results, context); | ||||
| } | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| // EmptyTensorListOp
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| 
 | ||||
| static LogicalResult Verify(EmptyTensorListOp op) { | ||||
|   if (!IsOfRankOrUnranked(op.element_shape(), 0) && | ||||
|       !IsOfRankOrUnranked(op.element_shape(), 1)) { | ||||
|     return op.emitOpError("requires element_shape operand to be 0D/1D tensor"); | ||||
|   } | ||||
| 
 | ||||
|   if (!IsOfRankOrUnranked(op.max_num_elements(), 0)) { | ||||
|     return op.emitOpError("requires max_num_elements operand to be 0D tensor"); | ||||
|   } | ||||
|   return success(); | ||||
| } | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| // FakeQuantWithMinMaxArgsOp
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
|  | ||||
| @ -30,6 +30,37 @@ limitations under the License. | ||||
| 
 | ||||
| include "tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td" | ||||
| 
 | ||||
| class TF_TensorListInitOp<string mnemonic> : TF_Op<mnemonic, [NoSideEffect]> { | ||||
|   let results = (outs | ||||
|     TF_VariantTensor:$handle | ||||
|   ); | ||||
| 
 | ||||
|   TF_DerivedOperandTypeAttr shape_type = TF_DerivedOperandTypeAttr<0>; | ||||
| 
 | ||||
|   let verifier = [{ | ||||
|     if (handle_dtype().getSubtypes().size() != 1) { | ||||
|       return emitOpError( | ||||
|           "must have exactly one subtype in the result variant type"); | ||||
|     } | ||||
| 
 | ||||
|     return Verify(*this); | ||||
|   }]; | ||||
| 
 | ||||
|   DerivedTypeAttr element_dtype = DerivedTypeAttr< | ||||
|       "return getElementTypeOrSelf(element_type());">; | ||||
| 
 | ||||
|   let extraClassDeclaration = [{ | ||||
|     // Returns type of the TensorList element produced by this op. | ||||
|     TensorType element_type() { return handle_dtype().getSubtypes()[0]; } | ||||
| 
 | ||||
|     // Returns data type of the result handle. Returned type contains type of | ||||
|     // the TensorList element as a subtype. | ||||
|     VariantType handle_dtype() { | ||||
|       return getElementTypeOrSelf(handle()->getType()).cast<TF::VariantType>(); | ||||
|     } | ||||
|   }]; | ||||
| } | ||||
| 
 | ||||
| // In MLIR, the TensorFlow tensor value is represented as an ElementsAttr, with | ||||
| // its type encoding the tensor's shape and data type. | ||||
| def TF_ConstOp : TF_Op<"Const", [NoSideEffect]> { | ||||
| @ -55,6 +86,24 @@ def TF_ConstOp : TF_Op<"Const", [NoSideEffect]> { | ||||
|   let hasFolder = 1; | ||||
| } | ||||
| 
 | ||||
| def TF_EmptyTensorListOp : TF_TensorListInitOp<"EmptyTensorList"> { | ||||
|   let summary = "Creates and returns an empty tensor list."; | ||||
| 
 | ||||
|   let description = [{ | ||||
| All list elements must be tensors of dtype element_dtype and shape compatible | ||||
| with element_shape. | ||||
| 
 | ||||
| handle: an empty tensor list. | ||||
| element_dtype: the type of elements in the list. | ||||
| element_shape: a shape compatible with that of elements in the list. | ||||
|   }]; | ||||
| 
 | ||||
|   let arguments = (ins | ||||
|     TF_I32OrI64Tensor:$element_shape, | ||||
|     I32Tensor:$max_num_elements | ||||
|   ); | ||||
| } | ||||
| 
 | ||||
| // TODO(fengliuai): The tf.Identity is side-effect free and it doesn't change | ||||
| // the status of the system during the execution. However it shouldn't be folded | ||||
| // in general if it used to serve for caching and some other invariant checks, | ||||
| @ -191,51 +240,6 @@ Inserts a placeholder for a tensor that will be always fed. | ||||
|   TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; | ||||
| } | ||||
| 
 | ||||
| def TF_TensorListReserveOp : TF_Op<"TensorListReserve", [NoSideEffect]> { | ||||
|   let summary = "List of the given size with empty elements."; | ||||
| 
 | ||||
|   let description = [{ | ||||
| element_shape: the shape of the future elements of the list | ||||
| num_elements: the number of elements to reserve | ||||
| handle: the output list | ||||
| element_dtype: the desired type of elements in the list. | ||||
|   }]; | ||||
| 
 | ||||
|   let arguments = (ins | ||||
|     TF_I32OrI64Tensor:$element_shape, | ||||
|     I32Tensor:$num_elements | ||||
|   ); | ||||
| 
 | ||||
|   let results = (outs | ||||
|     TF_VariantTensor:$handle | ||||
|   ); | ||||
| 
 | ||||
|   TF_DerivedOperandTypeAttr shape_type = TF_DerivedOperandTypeAttr<0>; | ||||
| 
 | ||||
|   let verifier = [{ | ||||
|     if (handle_dtype().getSubtypes().size() != 1) { | ||||
|       return emitOpError( | ||||
|           "must have exactly one subtype in the result variant type"); | ||||
|     } | ||||
| 
 | ||||
|     return Verify(*this); | ||||
|   }]; | ||||
| 
 | ||||
|   DerivedTypeAttr element_dtype = DerivedTypeAttr< | ||||
|       "return getElementTypeOrSelf(element_type());">; | ||||
| 
 | ||||
|   let extraClassDeclaration = [{ | ||||
|     // Returns type of the TensorList element produced by this op. | ||||
|     TensorType element_type() { return handle_dtype().getSubtypes()[0]; } | ||||
| 
 | ||||
|     // Returns data type of the result handle. Returned type contains type of | ||||
|     // the TensorList element as a subtype. | ||||
|     VariantType handle_dtype() { | ||||
|       return getElementTypeOrSelf(handle()->getType()).cast<TF::VariantType>(); | ||||
|     } | ||||
|   }]; | ||||
| } | ||||
| 
 | ||||
| def TF_WhileOp : TF_Op<"While", []> { | ||||
|   let summary = [{ | ||||
| output = input; While (Cond(output)) { output = Body(output) } | ||||
| @ -278,4 +282,20 @@ body: A function that takes a list of tensors and returns another | ||||
|   }]; | ||||
| } | ||||
| 
 | ||||
| def TF_TensorListReserveOp : TF_TensorListInitOp<"TensorListReserve"> { | ||||
|   let summary = "List of the given size with empty elements."; | ||||
| 
 | ||||
|   let description = [{ | ||||
| element_shape: the shape of the future elements of the list | ||||
| num_elements: the number of elements to reserve | ||||
| handle: the output list | ||||
| element_dtype: the desired type of elements in the list. | ||||
|   }]; | ||||
| 
 | ||||
|   let arguments = (ins | ||||
|     TF_I32OrI64Tensor:$element_shape, | ||||
|     I32Tensor:$num_elements | ||||
|   ); | ||||
| } | ||||
| 
 | ||||
| #endif // TF_OPS | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user