Internal cleanup: retire support for converting whole classes, which is not supported in TF2 and is unlikely to have real uses.
PiperOrigin-RevId: 302550027 Change-Id: I0127d56a3cc45ea38205df9a88480a75d25ea39c
This commit is contained in:
parent
ec5c9bebf8
commit
b5c7256118
@ -52,12 +52,11 @@ from tensorflow.python.autograph.core import naming
|
||||
from tensorflow.python.autograph.core import unsupported_features_checker
|
||||
from tensorflow.python.autograph.lang import special_functions
|
||||
from tensorflow.python.autograph.pyct import ast_util
|
||||
from tensorflow.python.autograph.pyct import loader
|
||||
from tensorflow.python.autograph.pyct import inspect_utils
|
||||
from tensorflow.python.autograph.pyct import loader
|
||||
from tensorflow.python.autograph.pyct import origin_info
|
||||
from tensorflow.python.autograph.pyct import parser
|
||||
from tensorflow.python.autograph.pyct import pretty_printer
|
||||
from tensorflow.python.autograph.pyct import qual_names
|
||||
from tensorflow.python.autograph.pyct import templates
|
||||
from tensorflow.python.autograph.pyct import transformer
|
||||
from tensorflow.python.autograph.utils import ag_logging as logging
|
||||
@ -299,15 +298,8 @@ def _instantiate(entity, converted_entity_info, free_nonglobal_var_names):
|
||||
"""Creates a converted instance and binds it to match original entity."""
|
||||
factory = converted_entity_info.get_factory()
|
||||
|
||||
# `factory` is currently bound to the empty module it was loaded from.
|
||||
# It must instead be bound to the globals and closure from the original
|
||||
# entity.
|
||||
if tf_inspect.isfunction(entity) or tf_inspect.ismethod(entity):
|
||||
entity_globals = entity.__globals__
|
||||
entity_closure = entity.__closure__ or ()
|
||||
elif hasattr(entity, '__module__'):
|
||||
entity_globals = sys.modules[entity.__module__].__dict__
|
||||
entity_closure = ()
|
||||
entity_globals = entity.__globals__
|
||||
entity_closure = entity.__closure__ or ()
|
||||
assert len(entity_closure) == len(free_nonglobal_var_names)
|
||||
|
||||
# Fit the original entity's cells to match the order of factory's cells.
|
||||
@ -328,11 +320,10 @@ def _instantiate(entity, converted_entity_info, free_nonglobal_var_names):
|
||||
ag_internal, converted_entity_info.source_map,
|
||||
converted_entity_info.get_module())
|
||||
|
||||
if tf_inspect.isfunction(entity) or tf_inspect.ismethod(entity):
|
||||
# Attach the default argument to the converted function.
|
||||
converted_entity.__defaults__ = entity.__defaults__
|
||||
if hasattr(entity, '__kwdefaults__'):
|
||||
converted_entity.__kwdefaults__ = entity.__kwdefaults__
|
||||
# Attach the default argument to the converted function.
|
||||
converted_entity.__defaults__ = entity.__defaults__
|
||||
if hasattr(entity, '__kwdefaults__'):
|
||||
converted_entity.__kwdefaults__ = entity.__kwdefaults__
|
||||
|
||||
return converted_entity
|
||||
|
||||
@ -340,14 +331,11 @@ def _instantiate(entity, converted_entity_info, free_nonglobal_var_names):
|
||||
def convert(entity, program_ctx):
|
||||
"""Converts an entity into an equivalent entity."""
|
||||
|
||||
if tf_inspect.isfunction(entity) or tf_inspect.ismethod(entity):
|
||||
if not hasattr(entity, '__code__'):
|
||||
raise ValueError('Cannot apply autograph to a function that doesn\'t '
|
||||
'expose a __code__ object. If this is a @tf.function,'
|
||||
' try passing f.python_function instead.')
|
||||
free_nonglobal_var_names = entity.__code__.co_freevars
|
||||
else:
|
||||
free_nonglobal_var_names = ()
|
||||
if not hasattr(entity, '__code__'):
|
||||
raise ValueError('Cannot apply autograph to a function that doesn\'t '
|
||||
'expose a __code__ object. If this is a @tf.function,'
|
||||
' try passing f.python_function instead.')
|
||||
free_nonglobal_var_names = entity.__code__.co_freevars
|
||||
|
||||
for i, name in enumerate(free_nonglobal_var_names):
|
||||
if (name == 'ag__' and
|
||||
@ -505,22 +493,7 @@ def convert_entity_to_ast(o, program_ctx):
|
||||
"""
|
||||
logging.log(1, 'Converting %s', o)
|
||||
|
||||
if tf_inspect.isclass(o):
|
||||
nodes, name, entity_info = convert_class_to_ast(o, program_ctx)
|
||||
elif tf_inspect.isfunction(o):
|
||||
nodes, name, entity_info = convert_func_to_ast(o, program_ctx)
|
||||
elif tf_inspect.ismethod(o):
|
||||
nodes, name, entity_info = convert_func_to_ast(o, program_ctx)
|
||||
elif hasattr(o, '__class__'):
|
||||
# Note: this should only be raised when attempting to convert the object
|
||||
# directly. converted_call should still support it.
|
||||
raise NotImplementedError(
|
||||
'cannot convert entity "{}": object conversion is not yet'
|
||||
' supported.'.format(o))
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
'Entity "%s" has unsupported type "%s". Only functions and classes are '
|
||||
'supported for now.' % (o, type(o)))
|
||||
nodes, name, entity_info = convert_func_to_ast(o, program_ctx)
|
||||
|
||||
if logging.has_verbosity(2):
|
||||
logging.log(2, 'Compiled output of %s:\n\n%s\n', o, parser.unparse(nodes))
|
||||
@ -532,106 +505,6 @@ def convert_entity_to_ast(o, program_ctx):
|
||||
return nodes, name, entity_info
|
||||
|
||||
|
||||
def convert_class_to_ast(c, program_ctx):
|
||||
"""Specialization of `convert_entity_to_ast` for classes."""
|
||||
# TODO(mdan): Revisit this altogether. Not sure we still need it.
|
||||
converted_members = {}
|
||||
method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m)
|
||||
members = tf_inspect.getmembers(c, predicate=method_filter)
|
||||
if not members:
|
||||
raise ValueError('cannot convert %s: no member methods' % c)
|
||||
|
||||
# TODO(mdan): Don't clobber namespaces for each method in one class namespace.
|
||||
# The assumption that one namespace suffices for all methods only holds if
|
||||
# all methods were defined in the same module.
|
||||
# If, instead, functions are imported from multiple modules and then spliced
|
||||
# into the class, then each function has its own globals and __future__
|
||||
# imports that need to stay separate.
|
||||
|
||||
# For example, C's methods could both have `global x` statements referring to
|
||||
# mod1.x and mod2.x, but using one namespace for C would cause a conflict.
|
||||
# from mod1 import f1
|
||||
# from mod2 import f2
|
||||
# class C(object):
|
||||
# method1 = f1
|
||||
# method2 = f2
|
||||
|
||||
class_namespace = {}
|
||||
future_features = None
|
||||
for _, m in members:
|
||||
# Only convert the members that are directly defined by the class.
|
||||
if inspect_utils.getdefiningclass(m, c) is not c:
|
||||
continue
|
||||
(node,), _, entity_info = convert_func_to_ast(
|
||||
m, program_ctx=program_ctx, do_rename=False)
|
||||
class_namespace.update(entity_info.namespace)
|
||||
converted_members[m] = node
|
||||
|
||||
# TODO(mdan): Similarly check the globals.
|
||||
if future_features is None:
|
||||
future_features = entity_info.future_features
|
||||
elif frozenset(future_features) ^ frozenset(entity_info.future_features):
|
||||
# Note: we can support this case if ever needed.
|
||||
raise ValueError(
|
||||
'cannot convert {}: if has methods built with mismatched future'
|
||||
' features: {} and {}'.format(c, future_features,
|
||||
entity_info.future_features))
|
||||
namer = naming.Namer(class_namespace)
|
||||
class_name = namer.class_name(c.__name__)
|
||||
|
||||
# Process any base classes: if the superclass if of a whitelisted type, an
|
||||
# absolute import line is generated.
|
||||
output_nodes = []
|
||||
renames = {}
|
||||
base_names = []
|
||||
for base in c.__bases__:
|
||||
if isinstance(object, base):
|
||||
base_names.append('object')
|
||||
continue
|
||||
if is_whitelisted(base):
|
||||
alias = namer.new_symbol(base.__name__, ())
|
||||
output_nodes.append(
|
||||
gast.ImportFrom(
|
||||
module=base.__module__,
|
||||
names=[gast.alias(name=base.__name__, asname=alias)],
|
||||
level=0))
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
'Conversion of classes that do not directly extend classes from'
|
||||
' whitelisted modules is temporarily suspended. If this breaks'
|
||||
' existing code please notify the AutoGraph team immediately.')
|
||||
base_names.append(alias)
|
||||
renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)
|
||||
|
||||
# Generate the definition of the converted class.
|
||||
bases = [
|
||||
gast.Name(n, ctx=gast.Load(), annotation=None, type_comment=None)
|
||||
for n in base_names]
|
||||
class_def = gast.ClassDef(
|
||||
class_name,
|
||||
bases=bases,
|
||||
keywords=[],
|
||||
body=list(converted_members.values()),
|
||||
decorator_list=[])
|
||||
# Make a final pass to replace references to the class or its base classes.
|
||||
# Most commonly, this occurs when making super().__init__() calls.
|
||||
# TODO(mdan): Making direct references to superclass' superclass will fail.
|
||||
class_def = qual_names.resolve(class_def)
|
||||
renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
|
||||
class_def = ast_util.rename_symbols(class_def, renames)
|
||||
|
||||
output_nodes.append(class_def)
|
||||
|
||||
# TODO(mdan): Find a way better than forging this object.
|
||||
entity_info = transformer.EntityInfo(
|
||||
source_code=None,
|
||||
source_file=None,
|
||||
future_features=future_features,
|
||||
namespace=class_namespace)
|
||||
|
||||
return output_nodes, class_name, entity_info
|
||||
|
||||
|
||||
def _add_reserved_symbol(namespace, name, entity):
|
||||
if name not in namespace:
|
||||
namespace[name] = entity
|
||||
|
@ -36,7 +36,6 @@ from tensorflow.python.autograph.impl.testing import pybind_for_testing
|
||||
from tensorflow.python.autograph.pyct import parser
|
||||
from tensorflow.python.eager import function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.keras.engine import training
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@ -127,11 +126,6 @@ class ConversionTest(test.TestCase):
|
||||
# Note: currently, native bindings are whitelisted by a separate check.
|
||||
self.assertFalse(conversion.is_whitelisted(test_object.method))
|
||||
|
||||
def test_convert_entity_to_ast_unsupported_types(self):
|
||||
with self.assertRaises(NotImplementedError):
|
||||
program_ctx = self._simple_program_ctx()
|
||||
conversion.convert_entity_to_ast('dummy', program_ctx)
|
||||
|
||||
def test_convert_entity_to_ast_callable(self):
|
||||
b = 2
|
||||
|
||||
@ -174,53 +168,6 @@ class ConversionTest(test.TestCase):
|
||||
f_node, = nodes
|
||||
self.assertEqual('tf__f', f_node.name)
|
||||
|
||||
def test_convert_entity_to_ast_class_hierarchy(self):
|
||||
|
||||
class TestBase(object):
|
||||
|
||||
def __init__(self, x='base'):
|
||||
self.x = x
|
||||
|
||||
def foo(self):
|
||||
return self.x
|
||||
|
||||
def bar(self):
|
||||
return self.x
|
||||
|
||||
class TestSubclass(TestBase):
|
||||
|
||||
def __init__(self, y):
|
||||
super(TestSubclass, self).__init__('sub')
|
||||
self.y = y
|
||||
|
||||
def foo(self):
|
||||
return self.y
|
||||
|
||||
def baz(self):
|
||||
return self.y
|
||||
|
||||
program_ctx = self._simple_program_ctx()
|
||||
with self.assertRaisesRegex(NotImplementedError, 'classes.*whitelisted'):
|
||||
conversion.convert_entity_to_ast(TestSubclass, program_ctx)
|
||||
|
||||
def test_convert_entity_to_ast_class_hierarchy_whitelisted(self):
|
||||
|
||||
class TestSubclass(training.Model):
|
||||
|
||||
def __init__(self, y):
|
||||
super(TestSubclass, self).__init__()
|
||||
self.built = False
|
||||
|
||||
def call(self, x):
|
||||
return 3 * x
|
||||
|
||||
program_ctx = self._simple_program_ctx()
|
||||
(import_node, class_node), name, _ = conversion.convert_entity_to_ast(
|
||||
TestSubclass, program_ctx)
|
||||
self.assertEqual(import_node.names[0].name, 'Model')
|
||||
self.assertEqual(name, 'TfTestSubclass')
|
||||
self.assertEqual(class_node.name, 'TfTestSubclass')
|
||||
|
||||
def test_convert_entity_to_ast_lambda(self):
|
||||
b = 2
|
||||
f = lambda x: b * x if x > 0 else -x
|
||||
|
Loading…
Reference in New Issue
Block a user