From b5c725611842a2abb3d504947a7895341c231d76 Mon Sep 17 00:00:00 2001 From: Dan Moldovan Date: Mon, 23 Mar 2020 16:59:53 -0700 Subject: [PATCH] Internal cleanup: retire support for converting whole classes, which is not supported in TF2 and is unlikely to have real uses. PiperOrigin-RevId: 302550027 Change-Id: I0127d56a3cc45ea38205df9a88480a75d25ea39c --- .../python/autograph/impl/conversion.py | 153 ++---------------- .../python/autograph/impl/conversion_test.py | 53 ------ 2 files changed, 13 insertions(+), 193 deletions(-) diff --git a/tensorflow/python/autograph/impl/conversion.py b/tensorflow/python/autograph/impl/conversion.py index 3e062f2eba0..e14c8e2bfcf 100644 --- a/tensorflow/python/autograph/impl/conversion.py +++ b/tensorflow/python/autograph/impl/conversion.py @@ -52,12 +52,11 @@ from tensorflow.python.autograph.core import naming from tensorflow.python.autograph.core import unsupported_features_checker from tensorflow.python.autograph.lang import special_functions from tensorflow.python.autograph.pyct import ast_util -from tensorflow.python.autograph.pyct import loader from tensorflow.python.autograph.pyct import inspect_utils +from tensorflow.python.autograph.pyct import loader from tensorflow.python.autograph.pyct import origin_info from tensorflow.python.autograph.pyct import parser from tensorflow.python.autograph.pyct import pretty_printer -from tensorflow.python.autograph.pyct import qual_names from tensorflow.python.autograph.pyct import templates from tensorflow.python.autograph.pyct import transformer from tensorflow.python.autograph.utils import ag_logging as logging @@ -299,15 +298,8 @@ def _instantiate(entity, converted_entity_info, free_nonglobal_var_names): """Creates a converted instance and binds it to match original entity.""" factory = converted_entity_info.get_factory() - # `factory` is currently bound to the empty module it was loaded from. - # It must instead be bound to the globals and closure from the original - # entity. - if tf_inspect.isfunction(entity) or tf_inspect.ismethod(entity): - entity_globals = entity.__globals__ - entity_closure = entity.__closure__ or () - elif hasattr(entity, '__module__'): - entity_globals = sys.modules[entity.__module__].__dict__ - entity_closure = () + entity_globals = entity.__globals__ + entity_closure = entity.__closure__ or () assert len(entity_closure) == len(free_nonglobal_var_names) # Fit the original entity's cells to match the order of factory's cells. @@ -328,11 +320,10 @@ def _instantiate(entity, converted_entity_info, free_nonglobal_var_names): ag_internal, converted_entity_info.source_map, converted_entity_info.get_module()) - if tf_inspect.isfunction(entity) or tf_inspect.ismethod(entity): - # Attach the default argument to the converted function. - converted_entity.__defaults__ = entity.__defaults__ - if hasattr(entity, '__kwdefaults__'): - converted_entity.__kwdefaults__ = entity.__kwdefaults__ + # Attach the default argument to the converted function. + converted_entity.__defaults__ = entity.__defaults__ + if hasattr(entity, '__kwdefaults__'): + converted_entity.__kwdefaults__ = entity.__kwdefaults__ return converted_entity @@ -340,14 +331,11 @@ def _instantiate(entity, converted_entity_info, free_nonglobal_var_names): def convert(entity, program_ctx): """Converts an entity into an equivalent entity.""" - if tf_inspect.isfunction(entity) or tf_inspect.ismethod(entity): - if not hasattr(entity, '__code__'): - raise ValueError('Cannot apply autograph to a function that doesn\'t ' - 'expose a __code__ object. If this is a @tf.function,' - ' try passing f.python_function instead.') - free_nonglobal_var_names = entity.__code__.co_freevars - else: - free_nonglobal_var_names = () + if not hasattr(entity, '__code__'): + raise ValueError('Cannot apply autograph to a function that doesn\'t ' + 'expose a __code__ object. If this is a @tf.function,' + ' try passing f.python_function instead.') + free_nonglobal_var_names = entity.__code__.co_freevars for i, name in enumerate(free_nonglobal_var_names): if (name == 'ag__' and @@ -505,22 +493,7 @@ def convert_entity_to_ast(o, program_ctx): """ logging.log(1, 'Converting %s', o) - if tf_inspect.isclass(o): - nodes, name, entity_info = convert_class_to_ast(o, program_ctx) - elif tf_inspect.isfunction(o): - nodes, name, entity_info = convert_func_to_ast(o, program_ctx) - elif tf_inspect.ismethod(o): - nodes, name, entity_info = convert_func_to_ast(o, program_ctx) - elif hasattr(o, '__class__'): - # Note: this should only be raised when attempting to convert the object - # directly. converted_call should still support it. - raise NotImplementedError( - 'cannot convert entity "{}": object conversion is not yet' - ' supported.'.format(o)) - else: - raise NotImplementedError( - 'Entity "%s" has unsupported type "%s". Only functions and classes are ' - 'supported for now.' % (o, type(o))) + nodes, name, entity_info = convert_func_to_ast(o, program_ctx) if logging.has_verbosity(2): logging.log(2, 'Compiled output of %s:\n\n%s\n', o, parser.unparse(nodes)) @@ -532,106 +505,6 @@ def convert_entity_to_ast(o, program_ctx): return nodes, name, entity_info -def convert_class_to_ast(c, program_ctx): - """Specialization of `convert_entity_to_ast` for classes.""" - # TODO(mdan): Revisit this altogether. Not sure we still need it. - converted_members = {} - method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m) - members = tf_inspect.getmembers(c, predicate=method_filter) - if not members: - raise ValueError('cannot convert %s: no member methods' % c) - - # TODO(mdan): Don't clobber namespaces for each method in one class namespace. - # The assumption that one namespace suffices for all methods only holds if - # all methods were defined in the same module. - # If, instead, functions are imported from multiple modules and then spliced - # into the class, then each function has its own globals and __future__ - # imports that need to stay separate. - - # For example, C's methods could both have `global x` statements referring to - # mod1.x and mod2.x, but using one namespace for C would cause a conflict. - # from mod1 import f1 - # from mod2 import f2 - # class C(object): - # method1 = f1 - # method2 = f2 - - class_namespace = {} - future_features = None - for _, m in members: - # Only convert the members that are directly defined by the class. - if inspect_utils.getdefiningclass(m, c) is not c: - continue - (node,), _, entity_info = convert_func_to_ast( - m, program_ctx=program_ctx, do_rename=False) - class_namespace.update(entity_info.namespace) - converted_members[m] = node - - # TODO(mdan): Similarly check the globals. - if future_features is None: - future_features = entity_info.future_features - elif frozenset(future_features) ^ frozenset(entity_info.future_features): - # Note: we can support this case if ever needed. - raise ValueError( - 'cannot convert {}: if has methods built with mismatched future' - ' features: {} and {}'.format(c, future_features, - entity_info.future_features)) - namer = naming.Namer(class_namespace) - class_name = namer.class_name(c.__name__) - - # Process any base classes: if the superclass if of a whitelisted type, an - # absolute import line is generated. - output_nodes = [] - renames = {} - base_names = [] - for base in c.__bases__: - if isinstance(object, base): - base_names.append('object') - continue - if is_whitelisted(base): - alias = namer.new_symbol(base.__name__, ()) - output_nodes.append( - gast.ImportFrom( - module=base.__module__, - names=[gast.alias(name=base.__name__, asname=alias)], - level=0)) - else: - raise NotImplementedError( - 'Conversion of classes that do not directly extend classes from' - ' whitelisted modules is temporarily suspended. If this breaks' - ' existing code please notify the AutoGraph team immediately.') - base_names.append(alias) - renames[qual_names.QN(base.__name__)] = qual_names.QN(alias) - - # Generate the definition of the converted class. - bases = [ - gast.Name(n, ctx=gast.Load(), annotation=None, type_comment=None) - for n in base_names] - class_def = gast.ClassDef( - class_name, - bases=bases, - keywords=[], - body=list(converted_members.values()), - decorator_list=[]) - # Make a final pass to replace references to the class or its base classes. - # Most commonly, this occurs when making super().__init__() calls. - # TODO(mdan): Making direct references to superclass' superclass will fail. - class_def = qual_names.resolve(class_def) - renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name) - class_def = ast_util.rename_symbols(class_def, renames) - - output_nodes.append(class_def) - - # TODO(mdan): Find a way better than forging this object. - entity_info = transformer.EntityInfo( - source_code=None, - source_file=None, - future_features=future_features, - namespace=class_namespace) - - return output_nodes, class_name, entity_info - - def _add_reserved_symbol(namespace, name, entity): if name not in namespace: namespace[name] = entity diff --git a/tensorflow/python/autograph/impl/conversion_test.py b/tensorflow/python/autograph/impl/conversion_test.py index 2453a51993c..b0c1e45cc45 100644 --- a/tensorflow/python/autograph/impl/conversion_test.py +++ b/tensorflow/python/autograph/impl/conversion_test.py @@ -36,7 +36,6 @@ from tensorflow.python.autograph.impl.testing import pybind_for_testing from tensorflow.python.autograph.pyct import parser from tensorflow.python.eager import function from tensorflow.python.framework import constant_op -from tensorflow.python.keras.engine import training from tensorflow.python.platform import test @@ -127,11 +126,6 @@ class ConversionTest(test.TestCase): # Note: currently, native bindings are whitelisted by a separate check. self.assertFalse(conversion.is_whitelisted(test_object.method)) - def test_convert_entity_to_ast_unsupported_types(self): - with self.assertRaises(NotImplementedError): - program_ctx = self._simple_program_ctx() - conversion.convert_entity_to_ast('dummy', program_ctx) - def test_convert_entity_to_ast_callable(self): b = 2 @@ -174,53 +168,6 @@ class ConversionTest(test.TestCase): f_node, = nodes self.assertEqual('tf__f', f_node.name) - def test_convert_entity_to_ast_class_hierarchy(self): - - class TestBase(object): - - def __init__(self, x='base'): - self.x = x - - def foo(self): - return self.x - - def bar(self): - return self.x - - class TestSubclass(TestBase): - - def __init__(self, y): - super(TestSubclass, self).__init__('sub') - self.y = y - - def foo(self): - return self.y - - def baz(self): - return self.y - - program_ctx = self._simple_program_ctx() - with self.assertRaisesRegex(NotImplementedError, 'classes.*whitelisted'): - conversion.convert_entity_to_ast(TestSubclass, program_ctx) - - def test_convert_entity_to_ast_class_hierarchy_whitelisted(self): - - class TestSubclass(training.Model): - - def __init__(self, y): - super(TestSubclass, self).__init__() - self.built = False - - def call(self, x): - return 3 * x - - program_ctx = self._simple_program_ctx() - (import_node, class_node), name, _ = conversion.convert_entity_to_ast( - TestSubclass, program_ctx) - self.assertEqual(import_node.names[0].name, 'Model') - self.assertEqual(name, 'TfTestSubclass') - self.assertEqual(class_node.name, 'TfTestSubclass') - def test_convert_entity_to_ast_lambda(self): b = 2 f = lambda x: b * x if x > 0 else -x