Rewrite the module whitelist into a basic rule-based set that allows both explicit whitelisting and non-whitelisting (e.g. to override a broader whitelist).
PiperOrigin-RevId: 250323405
This commit is contained in:
parent
2a12536258
commit
b211c7a053
tensorflow/python/autograph
@ -18,17 +18,11 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import enum
|
||||
|
||||
from tensorflow.python.autograph import utils
|
||||
|
||||
|
||||
PYTHON_LITERALS = {
|
||||
'None': None,
|
||||
'False': False,
|
||||
'True': True,
|
||||
'float': float,
|
||||
}
|
||||
|
||||
|
||||
def _internal_name(name):
|
||||
"""This function correctly resolves internal and external names."""
|
||||
reference_name = utils.__name__
|
||||
@ -47,14 +41,58 @@ def _internal_name(name):
|
||||
return root_prefix + '.' + name
|
||||
|
||||
|
||||
DEFAULT_UNCOMPILED_MODULES = set((
|
||||
('tensorflow',),
|
||||
(_internal_name('tensorflow'),),
|
||||
# TODO(mdan): Remove once the conversion process is optimized.
|
||||
('tensorflow_probability',),
|
||||
(_internal_name('tensorflow_probability'),),
|
||||
class Rule(object):
|
||||
"""Base class for conversion rules."""
|
||||
|
||||
def __init__(self, module_prefix):
|
||||
self._prefix = module_prefix
|
||||
|
||||
def matches(self, module_name):
|
||||
return (module_name.startswith(self._prefix + '.') or
|
||||
module_name == self._prefix)
|
||||
|
||||
|
||||
class Action(enum.Enum):
|
||||
NONE = 0
|
||||
CONVERT = 1
|
||||
DO_NOT_CONVERT = 2
|
||||
|
||||
|
||||
class DoNotConvert(Rule):
|
||||
"""Indicates that this module should be not converted."""
|
||||
|
||||
def __str__(self):
|
||||
return 'DoNotConvert rule for {}'.format(self._prefix)
|
||||
|
||||
def get_action(self, module):
|
||||
if self.matches(module.__name__):
|
||||
return Action.DO_NOT_CONVERT
|
||||
return Action.NONE
|
||||
|
||||
|
||||
class Convert(Rule):
|
||||
"""Indicates that this module should be converted."""
|
||||
|
||||
def __str__(self):
|
||||
return 'Convert rule for {}'.format(self._prefix)
|
||||
|
||||
def get_action(self, module):
|
||||
if self.matches(module.__name__):
|
||||
return Action.CONVERT
|
||||
return Action.NONE
|
||||
|
||||
|
||||
# This list is evaluated in order and stops at the first rule that tests True
|
||||
# for a definitely_convert of definitely_bypass call.
|
||||
CONVERSION_RULES = (
|
||||
DoNotConvert('tensorflow'),
|
||||
DoNotConvert(_internal_name('tensorflow')),
|
||||
|
||||
# TODO(b/133417201): Remove.
|
||||
DoNotConvert('tensorflow_probability'),
|
||||
DoNotConvert(_internal_name('tensorflow_probability')),
|
||||
|
||||
# TODO(b/130313089): Remove.
|
||||
('numpy',),
|
||||
# TODO(mdan): Might need to add "thread" as well?
|
||||
('threading',),
|
||||
))
|
||||
DoNotConvert('numpy'),
|
||||
DoNotConvert('threading'),
|
||||
)
|
||||
|
@ -333,11 +333,15 @@ def is_whitelisted_for_graph(o, check_call_override=True):
|
||||
else:
|
||||
m = tf_inspect.getmodule(o)
|
||||
|
||||
# Examples of callables that lack a __module__ property include builtins.
|
||||
if hasattr(m, '__name__'):
|
||||
# Builtins typically have unnamed modules.
|
||||
for prefix, in config.DEFAULT_UNCOMPILED_MODULES:
|
||||
if m.__name__.startswith(prefix + '.') or m.__name__ == prefix:
|
||||
logging.log(2, 'Whitelisted: %s: name starts with "%s"', o, prefix)
|
||||
for rule in config.CONVERSION_RULES:
|
||||
action = rule.get_action(m)
|
||||
if action == config.Action.CONVERT:
|
||||
logging.log(2, 'Not whitelisted: %s: %s', o, rule)
|
||||
return False
|
||||
elif action == config.Action.DO_NOT_CONVERT:
|
||||
logging.log(2, 'Whitelisted: %s: %s', o, rule)
|
||||
return True
|
||||
|
||||
if hasattr(o, 'autograph_info__') or hasattr(o, '__ag_compiled'):
|
||||
|
Loading…
Reference in New Issue
Block a user