diff --git a/tensorflow/python/autograph/impl/api.py b/tensorflow/python/autograph/impl/api.py index 3ebb5824b7f..98e19fdde86 100644 --- a/tensorflow/python/autograph/impl/api.py +++ b/tensorflow/python/autograph/impl/api.py @@ -18,13 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections -import copy import functools import inspect import os -import pdb -import re import sys import textwrap import traceback @@ -344,6 +340,15 @@ def _call_unconverted(f, args, kwargs, options, update_cache=True): return f(*args) +def _is_of_known_loaded_module(f, module_name): + mod = sys.modules.get(module_name, None) + if mod is None: + return False + if any(v is not None for v in mod.__dict__.values() if f is v): + return True + return False + + def _is_known_loaded_type(f, module_name, entity_name): """Tests whether the function or method is an instance of a known type.""" if (module_name not in sys.modules or @@ -511,7 +516,8 @@ def converted_call(f, # Other built-in modules are permanently whitelisted. # TODO(mdan): Figure out how to do this consistently for all stdlib modules. if any( - f in m.__dict__.values() for m in (collections, pdb, copy, inspect, re)): + _is_of_known_loaded_module(f, m) + for m in ('collections', 'pdb', 'copy', 'inspect', 're')): logging.log(2, 'Permanently whitelisted: %s: part of builtin module', f) return _call_unconverted(f, args, kwargs, options) diff --git a/tensorflow/python/autograph/impl/api_py3_test.py b/tensorflow/python/autograph/impl/api_py3_test.py index df6544928bf..c460e478008 100644 --- a/tensorflow/python/autograph/impl/api_py3_test.py +++ b/tensorflow/python/autograph/impl/api_py3_test.py @@ -19,6 +19,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import os from tensorflow.python.autograph.core import converter @@ -60,6 +61,23 @@ class ApiTest(test.TestCase): self.assertEqual(5, tc.no_arg(2)) + def test_converted_call_avoids_triggering_operators(self): + + test_self = self + + class Pair(collections.namedtuple('Pair', ['a', 'b'])): + + def __call__(self): + return self.a + self.b + + def __eq__(self, other): + test_self.fail('Triggered operator') + + p = Pair(constant_op.constant(1), constant_op.constant(2)) + + x = api.converted_call(p, (), {}, options=DEFAULT_RECURSIVE) + self.assertIsNotNone(self.evaluate(x), 3) + if __name__ == '__main__': os.environ['AUTOGRAPH_STRICT_CONVERSION'] = '1'