Use `is` instead of equality when checking for whitelisted modules, to avoid triggering side effects.

PiperOrigin-RevId: 312842395
Change-Id: Ie8294cdedb657adf69af90130ac354dff77220dc
This commit is contained in:
Dan Moldovan 2020-05-22 10:49:01 -07:00 committed by TensorFlower Gardener
parent c64097cb5f
commit 227024b31a
2 changed files with 29 additions and 5 deletions

View File

@ -18,13 +18,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import copy
import functools
import inspect
import os
import pdb
import re
import sys
import textwrap
import traceback
@ -344,6 +340,15 @@ def _call_unconverted(f, args, kwargs, options, update_cache=True):
return f(*args)
def _is_of_known_loaded_module(f, module_name):
mod = sys.modules.get(module_name, None)
if mod is None:
return False
if any(v is not None for v in mod.__dict__.values() if f is v):
return True
return False
def _is_known_loaded_type(f, module_name, entity_name):
"""Tests whether the function or method is an instance of a known type."""
if (module_name not in sys.modules or
@ -511,7 +516,8 @@ def converted_call(f,
# Other built-in modules are permanently whitelisted.
# TODO(mdan): Figure out how to do this consistently for all stdlib modules.
if any(
f in m.__dict__.values() for m in (collections, pdb, copy, inspect, re)):
_is_of_known_loaded_module(f, m)
for m in ('collections', 'pdb', 'copy', 'inspect', 're')):
logging.log(2, 'Permanently whitelisted: %s: part of builtin module', f)
return _call_unconverted(f, args, kwargs, options)

View File

@ -19,6 +19,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import os
from tensorflow.python.autograph.core import converter
@ -60,6 +61,23 @@ class ApiTest(test.TestCase):
self.assertEqual(5, tc.no_arg(2))
def test_converted_call_avoids_triggering_operators(self):
test_self = self
class Pair(collections.namedtuple('Pair', ['a', 'b'])):
def __call__(self):
return self.a + self.b
def __eq__(self, other):
test_self.fail('Triggered operator')
p = Pair(constant_op.constant(1), constant_op.constant(2))
x = api.converted_call(p, (), {}, options=DEFAULT_RECURSIVE)
self.assertIsNotNone(self.evaluate(x), 3)
if __name__ == '__main__':
os.environ['AUTOGRAPH_STRICT_CONVERSION'] = '1'