Internal cleanup: retire support for converting whole classes, which is not supported in TF2 and is unlikely to have real uses.

PiperOrigin-RevId: 302550027
Change-Id: I0127d56a3cc45ea38205df9a88480a75d25ea39c
This commit is contained in:
Dan Moldovan 2020-03-23 16:59:53 -07:00 committed by TensorFlower Gardener
parent ec5c9bebf8
commit b5c7256118
2 changed files with 13 additions and 193 deletions

View File

@ -52,12 +52,11 @@ from tensorflow.python.autograph.core import naming
from tensorflow.python.autograph.core import unsupported_features_checker
from tensorflow.python.autograph.lang import special_functions
from tensorflow.python.autograph.pyct import ast_util
from tensorflow.python.autograph.pyct import loader
from tensorflow.python.autograph.pyct import inspect_utils
from tensorflow.python.autograph.pyct import loader
from tensorflow.python.autograph.pyct import origin_info
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import pretty_printer
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import templates
from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.autograph.utils import ag_logging as logging
@ -299,15 +298,8 @@ def _instantiate(entity, converted_entity_info, free_nonglobal_var_names):
"""Creates a converted instance and binds it to match original entity."""
factory = converted_entity_info.get_factory()
# `factory` is currently bound to the empty module it was loaded from.
# It must instead be bound to the globals and closure from the original
# entity.
if tf_inspect.isfunction(entity) or tf_inspect.ismethod(entity):
entity_globals = entity.__globals__
entity_closure = entity.__closure__ or ()
elif hasattr(entity, '__module__'):
entity_globals = sys.modules[entity.__module__].__dict__
entity_closure = ()
assert len(entity_closure) == len(free_nonglobal_var_names)
# Fit the original entity's cells to match the order of factory's cells.
@ -328,7 +320,6 @@ def _instantiate(entity, converted_entity_info, free_nonglobal_var_names):
ag_internal, converted_entity_info.source_map,
converted_entity_info.get_module())
if tf_inspect.isfunction(entity) or tf_inspect.ismethod(entity):
# Attach the default argument to the converted function.
converted_entity.__defaults__ = entity.__defaults__
if hasattr(entity, '__kwdefaults__'):
@ -340,14 +331,11 @@ def _instantiate(entity, converted_entity_info, free_nonglobal_var_names):
def convert(entity, program_ctx):
"""Converts an entity into an equivalent entity."""
if tf_inspect.isfunction(entity) or tf_inspect.ismethod(entity):
if not hasattr(entity, '__code__'):
raise ValueError('Cannot apply autograph to a function that doesn\'t '
'expose a __code__ object. If this is a @tf.function,'
' try passing f.python_function instead.')
free_nonglobal_var_names = entity.__code__.co_freevars
else:
free_nonglobal_var_names = ()
for i, name in enumerate(free_nonglobal_var_names):
if (name == 'ag__' and
@ -505,22 +493,7 @@ def convert_entity_to_ast(o, program_ctx):
"""
logging.log(1, 'Converting %s', o)
if tf_inspect.isclass(o):
nodes, name, entity_info = convert_class_to_ast(o, program_ctx)
elif tf_inspect.isfunction(o):
nodes, name, entity_info = convert_func_to_ast(o, program_ctx)
elif tf_inspect.ismethod(o):
nodes, name, entity_info = convert_func_to_ast(o, program_ctx)
elif hasattr(o, '__class__'):
# Note: this should only be raised when attempting to convert the object
# directly. converted_call should still support it.
raise NotImplementedError(
'cannot convert entity "{}": object conversion is not yet'
' supported.'.format(o))
else:
raise NotImplementedError(
'Entity "%s" has unsupported type "%s". Only functions and classes are '
'supported for now.' % (o, type(o)))
if logging.has_verbosity(2):
logging.log(2, 'Compiled output of %s:\n\n%s\n', o, parser.unparse(nodes))
@ -532,106 +505,6 @@ def convert_entity_to_ast(o, program_ctx):
return nodes, name, entity_info
def convert_class_to_ast(c, program_ctx):
"""Specialization of `convert_entity_to_ast` for classes."""
# TODO(mdan): Revisit this altogether. Not sure we still need it.
converted_members = {}
method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m)
members = tf_inspect.getmembers(c, predicate=method_filter)
if not members:
raise ValueError('cannot convert %s: no member methods' % c)
# TODO(mdan): Don't clobber namespaces for each method in one class namespace.
# The assumption that one namespace suffices for all methods only holds if
# all methods were defined in the same module.
# If, instead, functions are imported from multiple modules and then spliced
# into the class, then each function has its own globals and __future__
# imports that need to stay separate.
# For example, C's methods could both have `global x` statements referring to
# mod1.x and mod2.x, but using one namespace for C would cause a conflict.
# from mod1 import f1
# from mod2 import f2
# class C(object):
# method1 = f1
# method2 = f2
class_namespace = {}
future_features = None
for _, m in members:
# Only convert the members that are directly defined by the class.
if inspect_utils.getdefiningclass(m, c) is not c:
continue
(node,), _, entity_info = convert_func_to_ast(
m, program_ctx=program_ctx, do_rename=False)
class_namespace.update(entity_info.namespace)
converted_members[m] = node
# TODO(mdan): Similarly check the globals.
if future_features is None:
future_features = entity_info.future_features
elif frozenset(future_features) ^ frozenset(entity_info.future_features):
# Note: we can support this case if ever needed.
raise ValueError(
'cannot convert {}: if has methods built with mismatched future'
' features: {} and {}'.format(c, future_features,
entity_info.future_features))
namer = naming.Namer(class_namespace)
class_name = namer.class_name(c.__name__)
# Process any base classes: if the superclass if of a whitelisted type, an
# absolute import line is generated.
output_nodes = []
renames = {}
base_names = []
for base in c.__bases__:
if isinstance(object, base):
base_names.append('object')
continue
if is_whitelisted(base):
alias = namer.new_symbol(base.__name__, ())
output_nodes.append(
gast.ImportFrom(
module=base.__module__,
names=[gast.alias(name=base.__name__, asname=alias)],
level=0))
else:
raise NotImplementedError(
'Conversion of classes that do not directly extend classes from'
' whitelisted modules is temporarily suspended. If this breaks'
' existing code please notify the AutoGraph team immediately.')
base_names.append(alias)
renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)
# Generate the definition of the converted class.
bases = [
gast.Name(n, ctx=gast.Load(), annotation=None, type_comment=None)
for n in base_names]
class_def = gast.ClassDef(
class_name,
bases=bases,
keywords=[],
body=list(converted_members.values()),
decorator_list=[])
# Make a final pass to replace references to the class or its base classes.
# Most commonly, this occurs when making super().__init__() calls.
# TODO(mdan): Making direct references to superclass' superclass will fail.
class_def = qual_names.resolve(class_def)
renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
class_def = ast_util.rename_symbols(class_def, renames)
output_nodes.append(class_def)
# TODO(mdan): Find a way better than forging this object.
entity_info = transformer.EntityInfo(
source_code=None,
source_file=None,
future_features=future_features,
namespace=class_namespace)
return output_nodes, class_name, entity_info
def _add_reserved_symbol(namespace, name, entity):
if name not in namespace:
namespace[name] = entity

View File

@ -36,7 +36,6 @@ from tensorflow.python.autograph.impl.testing import pybind_for_testing
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.keras.engine import training
from tensorflow.python.platform import test
@ -127,11 +126,6 @@ class ConversionTest(test.TestCase):
# Note: currently, native bindings are whitelisted by a separate check.
self.assertFalse(conversion.is_whitelisted(test_object.method))
def test_convert_entity_to_ast_unsupported_types(self):
with self.assertRaises(NotImplementedError):
program_ctx = self._simple_program_ctx()
conversion.convert_entity_to_ast('dummy', program_ctx)
def test_convert_entity_to_ast_callable(self):
b = 2
@ -174,53 +168,6 @@ class ConversionTest(test.TestCase):
f_node, = nodes
self.assertEqual('tf__f', f_node.name)
def test_convert_entity_to_ast_class_hierarchy(self):
class TestBase(object):
def __init__(self, x='base'):
self.x = x
def foo(self):
return self.x
def bar(self):
return self.x
class TestSubclass(TestBase):
def __init__(self, y):
super(TestSubclass, self).__init__('sub')
self.y = y
def foo(self):
return self.y
def baz(self):
return self.y
program_ctx = self._simple_program_ctx()
with self.assertRaisesRegex(NotImplementedError, 'classes.*whitelisted'):
conversion.convert_entity_to_ast(TestSubclass, program_ctx)
def test_convert_entity_to_ast_class_hierarchy_whitelisted(self):
class TestSubclass(training.Model):
def __init__(self, y):
super(TestSubclass, self).__init__()
self.built = False
def call(self, x):
return 3 * x
program_ctx = self._simple_program_ctx()
(import_node, class_node), name, _ = conversion.convert_entity_to_ast(
TestSubclass, program_ctx)
self.assertEqual(import_node.names[0].name, 'Model')
self.assertEqual(name, 'TfTestSubclass')
self.assertEqual(class_node.name, 'TfTestSubclass')
def test_convert_entity_to_ast_lambda(self):
b = 2
f = lambda x: b * x if x > 0 else -x