Don't flag unsupported variant type ops in tensorlist pass
TensorList pass need not know about all DT_VARIANT uses, so instead just use partial conversion. This would still flag/fail if one of the explicitly marked illegal ops are encountered. PiperOrigin-RevId: 313306614 Change-Id: I1e56d2ea8f82bf5a7b72f6507efa9310b04e1cad
This commit is contained in:
parent
a2f840d54e
commit
05653928da
@ -838,7 +838,8 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction(
|
|||||||
// TensorFlow operations that doesn't have operands and results of type
|
// TensorFlow operations that doesn't have operands and results of type
|
||||||
// variant are legal. Here, we don't distinguish between variants encoding
|
// variant are legal. Here, we don't distinguish between variants encoding
|
||||||
// TensorList or some other type as that information is not available here.
|
// TensorList or some other type as that information is not available here.
|
||||||
// This constraint should be relaxed to support other variant types in TFLite.
|
// Partial legalization is used below to still allow ops with variant types
|
||||||
|
// still.
|
||||||
auto is_legal = [](Operation *op) {
|
auto is_legal = [](Operation *op) {
|
||||||
auto is_not_variant = [](Type ty) {
|
auto is_not_variant = [](Type ty) {
|
||||||
return !ty.cast<ShapedType>().getElementType().isa<TF::VariantType>();
|
return !ty.cast<ShapedType>().getElementType().isa<TF::VariantType>();
|
||||||
@ -873,7 +874,7 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction(
|
|||||||
ConvertTensorListPushBack, ConvertTensorListReserve,
|
ConvertTensorListPushBack, ConvertTensorListReserve,
|
||||||
ConvertTensorListSetItem, ConvertTensorListStack,
|
ConvertTensorListSetItem, ConvertTensorListStack,
|
||||||
ConvertTensorListResize, ConvertWhile>(context);
|
ConvertTensorListResize, ConvertWhile>(context);
|
||||||
return applyFullConversion(func, target, patterns);
|
return applyPartialConversion(func, target, patterns);
|
||||||
}
|
}
|
||||||
|
|
||||||
void LowerStaticTensorListPass::runOnOperation() {
|
void LowerStaticTensorListPass::runOnOperation() {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user