Add a utility that allows finding a name for an entity, relative to an existing namespace.
PiperOrigin-RevId: 216211286
This commit is contained in:
parent
8ef3e7c8c0
commit
d1588d72a8
@ -67,6 +67,40 @@ def getnamespace(f):
|
||||
return namespace
|
||||
|
||||
|
||||
def getqualifiedname(namespace, object_, max_depth=2):
|
||||
"""Returns the name by which a value can be referred to in a given namespace.
|
||||
|
||||
This function will recurse inside modules, but it will not search objects for
|
||||
attributes. The recursion depth is controlled by max_depth.
|
||||
|
||||
Args:
|
||||
namespace: Dict[str, Any], the namespace to search into.
|
||||
object_: Any, the value to search.
|
||||
max_depth: Optional[int], a limit to the recursion depth when searching
|
||||
inside modules.
|
||||
Returns: Union[str, None], the fully-qualified name that resolves to the value
|
||||
o, or None if it couldn't be found.
|
||||
"""
|
||||
for name, value in namespace.items():
|
||||
# The value may be referenced by more than one symbol, case in which
|
||||
# any symbol will be fine. If the program contains symbol aliases that
|
||||
# change over time, this may capture a symbol that will later point to
|
||||
# something else.
|
||||
# TODO(mdan): Prefer the symbol that matches the value type name.
|
||||
if object_ is value:
|
||||
return name
|
||||
|
||||
# TODO(mdan): Use breadth-first search and avoid visiting modules twice.
|
||||
if max_depth:
|
||||
for name, value in namespace.items():
|
||||
if tf_inspect.ismodule(value):
|
||||
name_in_module = getqualifiedname(value.__dict__, object_,
|
||||
max_depth - 1)
|
||||
if name_in_module is not None:
|
||||
return '{}.{}'.format(name, name_in_module)
|
||||
return None
|
||||
|
||||
|
||||
def _get_unbound_function(m):
|
||||
# TODO(mdan): Figure out why six.get_unbound_function fails in some cases.
|
||||
# The failure case is for tf.keras.Model.
|
||||
|
@ -19,6 +19,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from functools import wraps
|
||||
import imp
|
||||
|
||||
import six
|
||||
|
||||
@ -127,6 +128,24 @@ class InspectUtilsTest(test.TestCase):
|
||||
self.assertEqual(ns['closed_over_primitive'], closed_over_primitive)
|
||||
self.assertTrue('local_var' not in ns)
|
||||
|
||||
def test_getqualifiedname(self):
|
||||
foo = object()
|
||||
qux = imp.new_module('quxmodule')
|
||||
bar = imp.new_module('barmodule')
|
||||
baz = object()
|
||||
bar.baz = baz
|
||||
|
||||
ns = {
|
||||
'foo': foo,
|
||||
'bar': bar,
|
||||
'qux': qux,
|
||||
}
|
||||
|
||||
self.assertIsNone(inspect_utils.getqualifiedname(ns, inspect_utils))
|
||||
self.assertEqual(inspect_utils.getqualifiedname(ns, foo), 'foo')
|
||||
self.assertEqual(inspect_utils.getqualifiedname(ns, bar), 'bar')
|
||||
self.assertEqual(inspect_utils.getqualifiedname(ns, baz), 'bar.baz')
|
||||
|
||||
def test_getmethodclass(self):
|
||||
|
||||
self.assertEqual(
|
||||
|
Loading…
Reference in New Issue
Block a user