Don't flag unsupported variant type ops in tensorlist pass

TensorList pass need not know about all DT_VARIANT uses, so instead just use
partial conversion. This would still flag/fail if one of the explicitly marked illegal ops are encountered.

PiperOrigin-RevId: 313306614
Change-Id: I1e56d2ea8f82bf5a7b72f6507efa9310b04e1cad
This commit is contained in:
Jacques Pienaar 2020-05-26 18:27:40 -07:00 committed by TensorFlower Gardener
parent a2f840d54e
commit 05653928da

View File

@ -838,7 +838,8 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction(
// TensorFlow operations that doesn't have operands and results of type
// variant are legal. Here, we don't distinguish between variants encoding
// TensorList or some other type as that information is not available here.
// This constraint should be relaxed to support other variant types in TFLite.
// Partial legalization is used below to still allow ops with variant types
// still.
auto is_legal = [](Operation *op) {
auto is_not_variant = [](Type ty) {
return !ty.cast<ShapedType>().getElementType().isa<TF::VariantType>();
@ -873,7 +874,7 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction(
ConvertTensorListPushBack, ConvertTensorListReserve,
ConvertTensorListSetItem, ConvertTensorListStack,
ConvertTensorListResize, ConvertWhile>(context);
return applyFullConversion(func, target, patterns);
return applyPartialConversion(func, target, patterns);
}
void LowerStaticTensorListPass::runOnOperation() {