Rewrite the module whitelist into a basic rule-based set that allows both explicit whitelisting and non-whitelisting (e.g. to override a broader whitelist).
PiperOrigin-RevId: 250323405
This commit is contained in:
parent
2a12536258
commit
b211c7a053
@ -18,17 +18,11 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import enum
|
||||||
|
|
||||||
from tensorflow.python.autograph import utils
|
from tensorflow.python.autograph import utils
|
||||||
|
|
||||||
|
|
||||||
PYTHON_LITERALS = {
|
|
||||||
'None': None,
|
|
||||||
'False': False,
|
|
||||||
'True': True,
|
|
||||||
'float': float,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _internal_name(name):
|
def _internal_name(name):
|
||||||
"""This function correctly resolves internal and external names."""
|
"""This function correctly resolves internal and external names."""
|
||||||
reference_name = utils.__name__
|
reference_name = utils.__name__
|
||||||
@ -47,14 +41,58 @@ def _internal_name(name):
|
|||||||
return root_prefix + '.' + name
|
return root_prefix + '.' + name
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_UNCOMPILED_MODULES = set((
|
class Rule(object):
|
||||||
('tensorflow',),
|
"""Base class for conversion rules."""
|
||||||
(_internal_name('tensorflow'),),
|
|
||||||
# TODO(mdan): Remove once the conversion process is optimized.
|
def __init__(self, module_prefix):
|
||||||
('tensorflow_probability',),
|
self._prefix = module_prefix
|
||||||
(_internal_name('tensorflow_probability'),),
|
|
||||||
|
def matches(self, module_name):
|
||||||
|
return (module_name.startswith(self._prefix + '.') or
|
||||||
|
module_name == self._prefix)
|
||||||
|
|
||||||
|
|
||||||
|
class Action(enum.Enum):
|
||||||
|
NONE = 0
|
||||||
|
CONVERT = 1
|
||||||
|
DO_NOT_CONVERT = 2
|
||||||
|
|
||||||
|
|
||||||
|
class DoNotConvert(Rule):
|
||||||
|
"""Indicates that this module should be not converted."""
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return 'DoNotConvert rule for {}'.format(self._prefix)
|
||||||
|
|
||||||
|
def get_action(self, module):
|
||||||
|
if self.matches(module.__name__):
|
||||||
|
return Action.DO_NOT_CONVERT
|
||||||
|
return Action.NONE
|
||||||
|
|
||||||
|
|
||||||
|
class Convert(Rule):
|
||||||
|
"""Indicates that this module should be converted."""
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return 'Convert rule for {}'.format(self._prefix)
|
||||||
|
|
||||||
|
def get_action(self, module):
|
||||||
|
if self.matches(module.__name__):
|
||||||
|
return Action.CONVERT
|
||||||
|
return Action.NONE
|
||||||
|
|
||||||
|
|
||||||
|
# This list is evaluated in order and stops at the first rule that tests True
|
||||||
|
# for a definitely_convert of definitely_bypass call.
|
||||||
|
CONVERSION_RULES = (
|
||||||
|
DoNotConvert('tensorflow'),
|
||||||
|
DoNotConvert(_internal_name('tensorflow')),
|
||||||
|
|
||||||
|
# TODO(b/133417201): Remove.
|
||||||
|
DoNotConvert('tensorflow_probability'),
|
||||||
|
DoNotConvert(_internal_name('tensorflow_probability')),
|
||||||
|
|
||||||
# TODO(b/130313089): Remove.
|
# TODO(b/130313089): Remove.
|
||||||
('numpy',),
|
DoNotConvert('numpy'),
|
||||||
# TODO(mdan): Might need to add "thread" as well?
|
DoNotConvert('threading'),
|
||||||
('threading',),
|
)
|
||||||
))
|
|
||||||
|
@ -333,11 +333,15 @@ def is_whitelisted_for_graph(o, check_call_override=True):
|
|||||||
else:
|
else:
|
||||||
m = tf_inspect.getmodule(o)
|
m = tf_inspect.getmodule(o)
|
||||||
|
|
||||||
|
# Examples of callables that lack a __module__ property include builtins.
|
||||||
if hasattr(m, '__name__'):
|
if hasattr(m, '__name__'):
|
||||||
# Builtins typically have unnamed modules.
|
for rule in config.CONVERSION_RULES:
|
||||||
for prefix, in config.DEFAULT_UNCOMPILED_MODULES:
|
action = rule.get_action(m)
|
||||||
if m.__name__.startswith(prefix + '.') or m.__name__ == prefix:
|
if action == config.Action.CONVERT:
|
||||||
logging.log(2, 'Whitelisted: %s: name starts with "%s"', o, prefix)
|
logging.log(2, 'Not whitelisted: %s: %s', o, rule)
|
||||||
|
return False
|
||||||
|
elif action == config.Action.DO_NOT_CONVERT:
|
||||||
|
logging.log(2, 'Whitelisted: %s: %s', o, rule)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if hasattr(o, 'autograph_info__') or hasattr(o, '__ag_compiled'):
|
if hasattr(o, 'autograph_info__') or hasattr(o, '__ag_compiled'):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user