diff --git a/tensorflow/python/autograph/pyct/inspect_utils.py b/tensorflow/python/autograph/pyct/inspect_utils.py index f2b780d7fcd..855a690c4b2 100644 --- a/tensorflow/python/autograph/pyct/inspect_utils.py +++ b/tensorflow/python/autograph/pyct/inspect_utils.py @@ -298,7 +298,7 @@ def getfutureimports(entity): Returns: A tuple of future strings """ - if not tf_inspect.isfunction(entity): + if not (tf_inspect.isfunction(entity) or tf_inspect.ismethod(entity)): return tuple() return tuple(sorted(name for name, value in entity.__globals__.items() if getattr(value, '__module__', None) == '__future__')) diff --git a/tensorflow/python/autograph/pyct/inspect_utils_test.py b/tensorflow/python/autograph/pyct/inspect_utils_test.py index 75b41d226f6..ab8e95e09c4 100644 --- a/tensorflow/python/autograph/pyct/inspect_utils_test.py +++ b/tensorflow/python/autograph/pyct/inspect_utils_test.py @@ -415,12 +415,18 @@ class InspectUtilsTest(test.TestCase): self.assertTrue(inspect_utils.isbuiltin(zip)) self.assertFalse(inspect_utils.isbuiltin(function_decorator)) - def test_getfutureimports_simple_case(self): + def test_getfutureimports_functions(self): expected_imports = ('absolute_import', 'division', 'print_function', 'with_statement') self.assertEqual(inspect_utils.getfutureimports(future_import_module.f), expected_imports) + def test_getfutureimports_methods(self): + expected_imports = ('absolute_import', 'division', 'print_function', + 'with_statement') + self.assertEqual(inspect_utils.getfutureimports(future_import_module.Foo.f), + expected_imports) + def test_super_wrapper_for_dynamic_attrs(self): a = object() diff --git a/tensorflow/python/autograph/pyct/testing/future_import_module.py b/tensorflow/python/autograph/pyct/testing/future_import_module.py index a167322dbfe..95698a18268 100644 --- a/tensorflow/python/autograph/pyct/testing/future_import_module.py +++ b/tensorflow/python/autograph/pyct/testing/future_import_module.py @@ -24,3 +24,9 @@ from __future__ import with_statement def f(): print('foo') + + +class Foo(object): + + def f(self): + print('foo')