Reduce the cost of serializing ConversionOptions to code, by using a more efficient inspect.util.getqualifiedname, reducing its max_depth and falling back to caching the value in the namespace. The latter step makes it more difficult to run the generated code afterwards, but it should in turn speed up the conversion process. This also adds an extra check to tf_decorator to improve robustness.
PiperOrigin-RevId: 225226256
This commit is contained in:
parent
3ae0654d41
commit
350791003d
@ -261,7 +261,7 @@ class CallTreeTransformer(converter.Base):
|
||||
func=func,
|
||||
owner=owner,
|
||||
options=self.ctx.program.options.to_ast(
|
||||
self.ctx.info.namespace,
|
||||
self.ctx,
|
||||
internal_convert_user_code=self.ctx.program.options.recursive),
|
||||
args=node.args)
|
||||
# TODO(mdan): Improve the template mechanism to better support this.
|
||||
|
@ -179,15 +179,14 @@ class ConversionOptions(object):
|
||||
return (Feature.ALL in self.optional_features or
|
||||
feature in self.optional_features)
|
||||
|
||||
def to_ast(self, namespace, internal_convert_user_code=None):
|
||||
def to_ast(self, ctx, internal_convert_user_code=None):
|
||||
"""Returns a representation of this object as an AST node.
|
||||
|
||||
The AST node encodes a constructor that would create an object with the
|
||||
same contents.
|
||||
|
||||
Args:
|
||||
namespace: Dict[str, Any], the namespace to use when serializing values to
|
||||
names.
|
||||
ctx: EntityContext, the entity with which this AST needs to be consistent.
|
||||
internal_convert_user_code: Optional[bool], allows ovrriding the
|
||||
corresponding value.
|
||||
|
||||
@ -205,10 +204,11 @@ class ConversionOptions(object):
|
||||
"""
|
||||
|
||||
def as_qualified_name(o):
|
||||
name = inspect_utils.getqualifiedname(namespace, o)
|
||||
name = inspect_utils.getqualifiedname(ctx.info.namespace, o, max_depth=1)
|
||||
if not name:
|
||||
raise ValueError('Could not locate entity {} in {}'.format(
|
||||
o, namespace))
|
||||
# TODO(mdan): This needs to account for the symbols defined locally.
|
||||
name = ctx.namer.new_symbol(o.__name__, ())
|
||||
ctx.program.add_symbol(name, o)
|
||||
return name
|
||||
|
||||
def list_of_names(values):
|
||||
@ -279,6 +279,7 @@ class ProgramContext(object):
|
||||
self.dependency_cache = {}
|
||||
self.additional_imports = set()
|
||||
self.name_map = {}
|
||||
self.additional_symbols = {}
|
||||
|
||||
@property
|
||||
def required_imports(self):
|
||||
@ -321,6 +322,11 @@ class ProgramContext(object):
|
||||
else:
|
||||
self.name_map[o] = name
|
||||
|
||||
def add_symbol(self, name, value):
|
||||
if name in self.additional_symbols:
|
||||
assert self.additional_symbols[name] is value
|
||||
self.additional_symbols[name] = value
|
||||
|
||||
def add_to_cache(self, original_entity, converted_ast):
|
||||
self.conversion_order.append(original_entity)
|
||||
self.dependency_cache[original_entity] = converted_ast
|
||||
|
@ -424,6 +424,9 @@ def to_graph(entity,
|
||||
# Avoid overwriting entities that have been transformed.
|
||||
if key not in compiled_module.__dict__:
|
||||
compiled_module.__dict__[key] = val
|
||||
for key, val in program_ctx.additional_symbols.items():
|
||||
if key not in compiled_module.__dict__:
|
||||
compiled_module.__dict__[key] = val
|
||||
compiled = getattr(compiled_module, name)
|
||||
|
||||
if tf_inspect.isfunction(entity):
|
||||
|
@ -101,7 +101,7 @@ def getnamespace(f):
|
||||
return namespace
|
||||
|
||||
|
||||
def getqualifiedname(namespace, object_, max_depth=2):
|
||||
def getqualifiedname(namespace, object_, max_depth=7, visited=None):
|
||||
"""Returns the name by which a value can be referred to in a given namespace.
|
||||
|
||||
If the object defines a parent module, the function attempts to use it to
|
||||
@ -115,16 +115,20 @@ def getqualifiedname(namespace, object_, max_depth=2):
|
||||
object_: Any, the value to search.
|
||||
max_depth: Optional[int], a limit to the recursion depth when searching
|
||||
inside modules.
|
||||
visited: Optional[Set[int]], ID of modules to avoid visiting.
|
||||
Returns: Union[str, None], the fully-qualified name that resolves to the value
|
||||
o, or None if it couldn't be found.
|
||||
"""
|
||||
for name, value in namespace.items():
|
||||
if visited is None:
|
||||
visited = set()
|
||||
|
||||
for name in namespace:
|
||||
# The value may be referenced by more than one symbol, case in which
|
||||
# any symbol will be fine. If the program contains symbol aliases that
|
||||
# change over time, this may capture a symbol that will later point to
|
||||
# something else.
|
||||
# TODO(mdan): Prefer the symbol that matches the value type name.
|
||||
if object_ is value:
|
||||
if object_ is namespace[name]:
|
||||
return name
|
||||
|
||||
# If an object is not found, try to search its parent modules.
|
||||
@ -132,22 +136,25 @@ def getqualifiedname(namespace, object_, max_depth=2):
|
||||
if (parent is not None and parent is not object_ and
|
||||
parent is not namespace):
|
||||
# No limit to recursion depth because of the guard above.
|
||||
parent_name = getqualifiedname(namespace, parent, max_depth=0)
|
||||
parent_name = getqualifiedname(
|
||||
namespace, parent, max_depth=0, visited=visited)
|
||||
if parent_name is not None:
|
||||
name_in_parent = getqualifiedname(parent.__dict__, object_, max_depth=0)
|
||||
name_in_parent = getqualifiedname(
|
||||
parent.__dict__, object_, max_depth=0, visited=visited)
|
||||
assert name_in_parent is not None, (
|
||||
'An object should always be found in its owner module')
|
||||
return '{}.{}'.format(parent_name, name_in_parent)
|
||||
|
||||
# TODO(mdan): Use breadth-first search and avoid visiting modules twice.
|
||||
if max_depth:
|
||||
# Iterating over a copy prevents "changed size due to iteration" errors.
|
||||
# It's unclear why those occur - suspecting new modules may load during
|
||||
# iteration.
|
||||
for name, value in namespace.copy().items():
|
||||
if tf_inspect.ismodule(value):
|
||||
for name in tuple(namespace.keys()):
|
||||
value = namespace[name]
|
||||
if tf_inspect.ismodule(value) and id(value) not in visited:
|
||||
visited.add(id(value))
|
||||
name_in_module = getqualifiedname(value.__dict__, object_,
|
||||
max_depth - 1)
|
||||
max_depth - 1, visited)
|
||||
if name_in_module is not None:
|
||||
return '{}.{}'.format(name, name_in_module)
|
||||
return None
|
||||
|
@ -183,6 +183,63 @@ class InspectUtilsTest(test.TestCase):
|
||||
self.assertEqual(inspect_utils.getqualifiedname(ns, bar), 'bar')
|
||||
self.assertEqual(inspect_utils.getqualifiedname(ns, baz), 'bar.baz')
|
||||
|
||||
def test_getqualifiedname_efficiency(self):
|
||||
foo = object()
|
||||
|
||||
# We create a densely connected graph consisting of a relatively small
|
||||
# number of modules and hide our symbol in one of them. The path to the
|
||||
# symbol is at least 10, and each node has about 10 neighbors. However,
|
||||
# by skipping visited modules, the search should take much less.
|
||||
ns = {}
|
||||
prev_level = []
|
||||
for i in range(10):
|
||||
current_level = []
|
||||
for j in range(10):
|
||||
mod_name = 'mod_{}_{}'.format(i, j)
|
||||
mod = imp.new_module(mod_name)
|
||||
current_level.append(mod)
|
||||
if i == 9 and j == 9:
|
||||
mod.foo = foo
|
||||
if prev_level:
|
||||
# All modules at level i refer to all modules at level i+1
|
||||
for prev in prev_level:
|
||||
for mod in current_level:
|
||||
prev.__dict__[mod.__name__] = mod
|
||||
else:
|
||||
for mod in current_level:
|
||||
ns[mod.__name__] = mod
|
||||
prev_level = current_level
|
||||
|
||||
self.assertIsNone(inspect_utils.getqualifiedname(ns, inspect_utils))
|
||||
self.assertIsNotNone(
|
||||
inspect_utils.getqualifiedname(ns, foo, max_depth=10000000000))
|
||||
|
||||
def test_getqualifiedname_cycles(self):
|
||||
foo = object()
|
||||
|
||||
# We create a graph of modules that contains circular references. The
|
||||
# search process should avoid them. The searched object is hidden at the
|
||||
# bottom of a path of length roughly 10.
|
||||
ns = {}
|
||||
mods = []
|
||||
for i in range(10):
|
||||
mod = imp.new_module('mod_{}'.format(i))
|
||||
if i == 9:
|
||||
mod.foo = foo
|
||||
# Module i refers to module i+1
|
||||
if mods:
|
||||
mods[-1].__dict__[mod.__name__] = mod
|
||||
else:
|
||||
ns[mod.__name__] = mod
|
||||
# Module i refers to all modules j < i.
|
||||
for prev in mods:
|
||||
mod.__dict__[prev.__name__] = prev
|
||||
mods.append(mod)
|
||||
|
||||
self.assertIsNone(inspect_utils.getqualifiedname(ns, inspect_utils))
|
||||
self.assertIsNotNone(
|
||||
inspect_utils.getqualifiedname(ns, foo, max_depth=10000000000))
|
||||
|
||||
def test_getqualifiedname_finds_via_parent_module(self):
|
||||
# TODO(mdan): This test is vulnerable to change in the lib module.
|
||||
# A better way to forge modules should be found.
|
||||
|
@ -98,6 +98,9 @@ def make_decorator(target,
|
||||
if hasattr(target, '__doc__'):
|
||||
decorator_func.__doc__ = decorator.__doc__
|
||||
decorator_func.__wrapped__ = target
|
||||
# Keeping a second handle to `target` allows callers to detect whether the
|
||||
# decorator was modified using `rewrap`.
|
||||
decorator_func.__original_wrapped__ = target
|
||||
return decorator_func
|
||||
|
||||
|
||||
@ -173,6 +176,8 @@ def unwrap(maybe_tf_decorator):
|
||||
decorators.append(getattr(cur, '_tf_decorator'))
|
||||
else:
|
||||
break
|
||||
if not hasattr(decorators[-1], 'decorated_target'):
|
||||
break
|
||||
cur = decorators[-1].decorated_target
|
||||
return decorators, cur
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user