Use `is` instead of equality when checking for whitelisted modules, to avoid triggering side effects.
PiperOrigin-RevId: 312842395 Change-Id: Ie8294cdedb657adf69af90130ac354dff77220dc
This commit is contained in:
parent
c64097cb5f
commit
227024b31a
|
@ -18,13 +18,9 @@ from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import collections
|
|
||||||
import copy
|
|
||||||
import functools
|
import functools
|
||||||
import inspect
|
import inspect
|
||||||
import os
|
import os
|
||||||
import pdb
|
|
||||||
import re
|
|
||||||
import sys
|
import sys
|
||||||
import textwrap
|
import textwrap
|
||||||
import traceback
|
import traceback
|
||||||
|
@ -344,6 +340,15 @@ def _call_unconverted(f, args, kwargs, options, update_cache=True):
|
||||||
return f(*args)
|
return f(*args)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_of_known_loaded_module(f, module_name):
|
||||||
|
mod = sys.modules.get(module_name, None)
|
||||||
|
if mod is None:
|
||||||
|
return False
|
||||||
|
if any(v is not None for v in mod.__dict__.values() if f is v):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _is_known_loaded_type(f, module_name, entity_name):
|
def _is_known_loaded_type(f, module_name, entity_name):
|
||||||
"""Tests whether the function or method is an instance of a known type."""
|
"""Tests whether the function or method is an instance of a known type."""
|
||||||
if (module_name not in sys.modules or
|
if (module_name not in sys.modules or
|
||||||
|
@ -511,7 +516,8 @@ def converted_call(f,
|
||||||
# Other built-in modules are permanently whitelisted.
|
# Other built-in modules are permanently whitelisted.
|
||||||
# TODO(mdan): Figure out how to do this consistently for all stdlib modules.
|
# TODO(mdan): Figure out how to do this consistently for all stdlib modules.
|
||||||
if any(
|
if any(
|
||||||
f in m.__dict__.values() for m in (collections, pdb, copy, inspect, re)):
|
_is_of_known_loaded_module(f, m)
|
||||||
|
for m in ('collections', 'pdb', 'copy', 'inspect', 're')):
|
||||||
logging.log(2, 'Permanently whitelisted: %s: part of builtin module', f)
|
logging.log(2, 'Permanently whitelisted: %s: part of builtin module', f)
|
||||||
return _call_unconverted(f, args, kwargs, options)
|
return _call_unconverted(f, args, kwargs, options)
|
||||||
|
|
||||||
|
|
|
@ -19,6 +19,7 @@ from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import collections
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from tensorflow.python.autograph.core import converter
|
from tensorflow.python.autograph.core import converter
|
||||||
|
@ -60,6 +61,23 @@ class ApiTest(test.TestCase):
|
||||||
|
|
||||||
self.assertEqual(5, tc.no_arg(2))
|
self.assertEqual(5, tc.no_arg(2))
|
||||||
|
|
||||||
|
def test_converted_call_avoids_triggering_operators(self):
|
||||||
|
|
||||||
|
test_self = self
|
||||||
|
|
||||||
|
class Pair(collections.namedtuple('Pair', ['a', 'b'])):
|
||||||
|
|
||||||
|
def __call__(self):
|
||||||
|
return self.a + self.b
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
test_self.fail('Triggered operator')
|
||||||
|
|
||||||
|
p = Pair(constant_op.constant(1), constant_op.constant(2))
|
||||||
|
|
||||||
|
x = api.converted_call(p, (), {}, options=DEFAULT_RECURSIVE)
|
||||||
|
self.assertIsNotNone(self.evaluate(x), 3)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
os.environ['AUTOGRAPH_STRICT_CONVERSION'] = '1'
|
os.environ['AUTOGRAPH_STRICT_CONVERSION'] = '1'
|
||||||
|
|
Loading…
Reference in New Issue