Use `is` instead of equality when checking for whitelisted modules, to avoid triggering side effects.
PiperOrigin-RevId: 312842395 Change-Id: Ie8294cdedb657adf69af90130ac354dff77220dc
This commit is contained in:
parent
c64097cb5f
commit
227024b31a
|
@ -18,13 +18,9 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import copy
|
||||
import functools
|
||||
import inspect
|
||||
import os
|
||||
import pdb
|
||||
import re
|
||||
import sys
|
||||
import textwrap
|
||||
import traceback
|
||||
|
@ -344,6 +340,15 @@ def _call_unconverted(f, args, kwargs, options, update_cache=True):
|
|||
return f(*args)
|
||||
|
||||
|
||||
def _is_of_known_loaded_module(f, module_name):
|
||||
mod = sys.modules.get(module_name, None)
|
||||
if mod is None:
|
||||
return False
|
||||
if any(v is not None for v in mod.__dict__.values() if f is v):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_known_loaded_type(f, module_name, entity_name):
|
||||
"""Tests whether the function or method is an instance of a known type."""
|
||||
if (module_name not in sys.modules or
|
||||
|
@ -511,7 +516,8 @@ def converted_call(f,
|
|||
# Other built-in modules are permanently whitelisted.
|
||||
# TODO(mdan): Figure out how to do this consistently for all stdlib modules.
|
||||
if any(
|
||||
f in m.__dict__.values() for m in (collections, pdb, copy, inspect, re)):
|
||||
_is_of_known_loaded_module(f, m)
|
||||
for m in ('collections', 'pdb', 'copy', 'inspect', 're')):
|
||||
logging.log(2, 'Permanently whitelisted: %s: part of builtin module', f)
|
||||
return _call_unconverted(f, args, kwargs, options)
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import os
|
||||
|
||||
from tensorflow.python.autograph.core import converter
|
||||
|
@ -60,6 +61,23 @@ class ApiTest(test.TestCase):
|
|||
|
||||
self.assertEqual(5, tc.no_arg(2))
|
||||
|
||||
def test_converted_call_avoids_triggering_operators(self):
|
||||
|
||||
test_self = self
|
||||
|
||||
class Pair(collections.namedtuple('Pair', ['a', 'b'])):
|
||||
|
||||
def __call__(self):
|
||||
return self.a + self.b
|
||||
|
||||
def __eq__(self, other):
|
||||
test_self.fail('Triggered operator')
|
||||
|
||||
p = Pair(constant_op.constant(1), constant_op.constant(2))
|
||||
|
||||
x = api.converted_call(p, (), {}, options=DEFAULT_RECURSIVE)
|
||||
self.assertIsNotNone(self.evaluate(x), 3)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
os.environ['AUTOGRAPH_STRICT_CONVERSION'] = '1'
|
||||
|
|
Loading…
Reference in New Issue