Use `is` instead of equality when checking for whitelisted modules, to avoid triggering side effects.

PiperOrigin-RevId: 312842395
Change-Id: Ie8294cdedb657adf69af90130ac354dff77220dc
This commit is contained in:
Dan Moldovan 2020-05-22 10:49:01 -07:00 committed by TensorFlower Gardener
parent c64097cb5f
commit 227024b31a
2 changed files with 29 additions and 5 deletions

View File

@ -18,13 +18,9 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import collections
import copy
import functools import functools
import inspect import inspect
import os import os
import pdb
import re
import sys import sys
import textwrap import textwrap
import traceback import traceback
@ -344,6 +340,15 @@ def _call_unconverted(f, args, kwargs, options, update_cache=True):
return f(*args) return f(*args)
def _is_of_known_loaded_module(f, module_name):
mod = sys.modules.get(module_name, None)
if mod is None:
return False
if any(v is not None for v in mod.__dict__.values() if f is v):
return True
return False
def _is_known_loaded_type(f, module_name, entity_name): def _is_known_loaded_type(f, module_name, entity_name):
"""Tests whether the function or method is an instance of a known type.""" """Tests whether the function or method is an instance of a known type."""
if (module_name not in sys.modules or if (module_name not in sys.modules or
@ -511,7 +516,8 @@ def converted_call(f,
# Other built-in modules are permanently whitelisted. # Other built-in modules are permanently whitelisted.
# TODO(mdan): Figure out how to do this consistently for all stdlib modules. # TODO(mdan): Figure out how to do this consistently for all stdlib modules.
if any( if any(
f in m.__dict__.values() for m in (collections, pdb, copy, inspect, re)): _is_of_known_loaded_module(f, m)
for m in ('collections', 'pdb', 'copy', 'inspect', 're')):
logging.log(2, 'Permanently whitelisted: %s: part of builtin module', f) logging.log(2, 'Permanently whitelisted: %s: part of builtin module', f)
return _call_unconverted(f, args, kwargs, options) return _call_unconverted(f, args, kwargs, options)

View File

@ -19,6 +19,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import collections
import os import os
from tensorflow.python.autograph.core import converter from tensorflow.python.autograph.core import converter
@ -60,6 +61,23 @@ class ApiTest(test.TestCase):
self.assertEqual(5, tc.no_arg(2)) self.assertEqual(5, tc.no_arg(2))
def test_converted_call_avoids_triggering_operators(self):
test_self = self
class Pair(collections.namedtuple('Pair', ['a', 'b'])):
def __call__(self):
return self.a + self.b
def __eq__(self, other):
test_self.fail('Triggered operator')
p = Pair(constant_op.constant(1), constant_op.constant(2))
x = api.converted_call(p, (), {}, options=DEFAULT_RECURSIVE)
self.assertIsNotNone(self.evaluate(x), 3)
if __name__ == '__main__': if __name__ == '__main__':
os.environ['AUTOGRAPH_STRICT_CONVERSION'] = '1' os.environ['AUTOGRAPH_STRICT_CONVERSION'] = '1'