diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc index 49be29065fe..45b8c9e5fb2 100644 --- a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc +++ b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc @@ -838,7 +838,8 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction( // TensorFlow operations that doesn't have operands and results of type // variant are legal. Here, we don't distinguish between variants encoding // TensorList or some other type as that information is not available here. - // This constraint should be relaxed to support other variant types in TFLite. + // Partial legalization is used below to still allow ops with variant types + // still. auto is_legal = [](Operation *op) { auto is_not_variant = [](Type ty) { return !ty.cast<ShapedType>().getElementType().isa<TF::VariantType>(); @@ -873,7 +874,7 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction( ConvertTensorListPushBack, ConvertTensorListReserve, ConvertTensorListSetItem, ConvertTensorListStack, ConvertTensorListResize, ConvertWhile>(context); - return applyFullConversion(func, target, patterns); + return applyPartialConversion(func, target, patterns); } void LowerStaticTensorListPass::runOnOperation() {