Fix whitelist bug which incorrectly matched all modules whose name begins with tensorflow, e.g. "tensorflow_foo". Remove tensorboard whitelist.
PiperOrigin-RevId: 247514491
This commit is contained in:
parent
6f980870f3
commit
e934a65fd4
@ -335,16 +335,10 @@ def is_whitelisted_for_graph(o, check_call_override=True):
|
|||||||
if hasattr(m, '__name__'):
|
if hasattr(m, '__name__'):
|
||||||
# Builtins typically have unnamed modules.
|
# Builtins typically have unnamed modules.
|
||||||
for prefix, in config.DEFAULT_UNCOMPILED_MODULES:
|
for prefix, in config.DEFAULT_UNCOMPILED_MODULES:
|
||||||
if m.__name__.startswith(prefix):
|
if m.__name__.startswith(prefix + '.') or m.__name__ == prefix:
|
||||||
logging.log(2, 'Whitelisted: %s: name starts with "%s"', o, prefix)
|
logging.log(2, 'Whitelisted: %s: name starts with "%s"', o, prefix)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Temporary -- whitelist tensorboard modules.
|
|
||||||
# TODO(b/122731813): Remove.
|
|
||||||
if m.__name__ == 'tensorboard' or '.tensorboard' in m.__name__:
|
|
||||||
logging.log(2, 'Whitelisted: %s: name contains "tensorboard"', o)
|
|
||||||
return True
|
|
||||||
|
|
||||||
if hasattr(o, 'autograph_info__') or hasattr(o, '__ag_compiled'):
|
if hasattr(o, 'autograph_info__') or hasattr(o, '__ag_compiled'):
|
||||||
logging.log(2, 'Whitelisted: %s: already converted', o)
|
logging.log(2, 'Whitelisted: %s: already converted', o)
|
||||||
return True
|
return True
|
||||||
|
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import imp
|
||||||
import gast
|
import gast
|
||||||
|
|
||||||
from tensorflow.python.autograph import utils
|
from tensorflow.python.autograph import utils
|
||||||
@ -46,6 +47,16 @@ class ConversionTest(test.TestCase):
|
|||||||
self.assertTrue(conversion.is_whitelisted_for_graph(utils))
|
self.assertTrue(conversion.is_whitelisted_for_graph(utils))
|
||||||
self.assertTrue(conversion.is_whitelisted_for_graph(constant_op.constant))
|
self.assertTrue(conversion.is_whitelisted_for_graph(constant_op.constant))
|
||||||
|
|
||||||
|
def test_is_whitelisted_for_graph_tensorflow_like(self):
|
||||||
|
|
||||||
|
tf_like = imp.new_module('tensorflow_foo')
|
||||||
|
def test_fn():
|
||||||
|
pass
|
||||||
|
tf_like.test_fn = test_fn
|
||||||
|
test_fn.__module__ = tf_like
|
||||||
|
|
||||||
|
self.assertFalse(conversion.is_whitelisted_for_graph(tf_like.test_fn))
|
||||||
|
|
||||||
def test_convert_entity_to_ast_unsupported_types(self):
|
def test_convert_entity_to_ast_unsupported_types(self):
|
||||||
with self.assertRaises(NotImplementedError):
|
with self.assertRaises(NotImplementedError):
|
||||||
program_ctx = self._simple_program_ctx()
|
program_ctx = self._simple_program_ctx()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user