Rewrite the module whitelist into a basic rule-based set that allows both explicit whitelisting and non-whitelisting (e.g. to override a broader whitelist).

PiperOrigin-RevId: 250323405
This commit is contained in:
Dan Moldovan 2019-05-28 11:15:57 -07:00 committed by TensorFlower Gardener
parent 2a12536258
commit b211c7a053
2 changed files with 64 additions and 22 deletions
tensorflow/python/autograph

View File

@ -18,17 +18,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import enum
from tensorflow.python.autograph import utils
PYTHON_LITERALS = {
'None': None,
'False': False,
'True': True,
'float': float,
}
def _internal_name(name):
"""This function correctly resolves internal and external names."""
reference_name = utils.__name__
@ -47,14 +41,58 @@ def _internal_name(name):
return root_prefix + '.' + name
DEFAULT_UNCOMPILED_MODULES = set((
('tensorflow',),
(_internal_name('tensorflow'),),
# TODO(mdan): Remove once the conversion process is optimized.
('tensorflow_probability',),
(_internal_name('tensorflow_probability'),),
class Rule(object):
"""Base class for conversion rules."""
def __init__(self, module_prefix):
self._prefix = module_prefix
def matches(self, module_name):
return (module_name.startswith(self._prefix + '.') or
module_name == self._prefix)
class Action(enum.Enum):
NONE = 0
CONVERT = 1
DO_NOT_CONVERT = 2
class DoNotConvert(Rule):
"""Indicates that this module should be not converted."""
def __str__(self):
return 'DoNotConvert rule for {}'.format(self._prefix)
def get_action(self, module):
if self.matches(module.__name__):
return Action.DO_NOT_CONVERT
return Action.NONE
class Convert(Rule):
"""Indicates that this module should be converted."""
def __str__(self):
return 'Convert rule for {}'.format(self._prefix)
def get_action(self, module):
if self.matches(module.__name__):
return Action.CONVERT
return Action.NONE
# This list is evaluated in order and stops at the first rule that tests True
# for a definitely_convert of definitely_bypass call.
CONVERSION_RULES = (
DoNotConvert('tensorflow'),
DoNotConvert(_internal_name('tensorflow')),
# TODO(b/133417201): Remove.
DoNotConvert('tensorflow_probability'),
DoNotConvert(_internal_name('tensorflow_probability')),
# TODO(b/130313089): Remove.
('numpy',),
# TODO(mdan): Might need to add "thread" as well?
('threading',),
))
DoNotConvert('numpy'),
DoNotConvert('threading'),
)

View File

@ -333,11 +333,15 @@ def is_whitelisted_for_graph(o, check_call_override=True):
else:
m = tf_inspect.getmodule(o)
# Examples of callables that lack a __module__ property include builtins.
if hasattr(m, '__name__'):
# Builtins typically have unnamed modules.
for prefix, in config.DEFAULT_UNCOMPILED_MODULES:
if m.__name__.startswith(prefix + '.') or m.__name__ == prefix:
logging.log(2, 'Whitelisted: %s: name starts with "%s"', o, prefix)
for rule in config.CONVERSION_RULES:
action = rule.get_action(m)
if action == config.Action.CONVERT:
logging.log(2, 'Not whitelisted: %s: %s', o, rule)
return False
elif action == config.Action.DO_NOT_CONVERT:
logging.log(2, 'Whitelisted: %s: %s', o, rule)
return True
if hasattr(o, 'autograph_info__') or hasattr(o, '__ag_compiled'):