Add a utility that allows finding a name for an entity, relative to an existing namespace.

PiperOrigin-RevId: 216211286
This commit is contained in:
Dan Moldovan 2018-10-08 10:43:03 -07:00 committed by TensorFlower Gardener
parent 8ef3e7c8c0
commit d1588d72a8
2 changed files with 53 additions and 0 deletions

View File

@ -67,6 +67,40 @@ def getnamespace(f):
return namespace
def getqualifiedname(namespace, object_, max_depth=2):
"""Returns the name by which a value can be referred to in a given namespace.
This function will recurse inside modules, but it will not search objects for
attributes. The recursion depth is controlled by max_depth.
Args:
namespace: Dict[str, Any], the namespace to search into.
object_: Any, the value to search.
max_depth: Optional[int], a limit to the recursion depth when searching
inside modules.
Returns: Union[str, None], the fully-qualified name that resolves to the value
o, or None if it couldn't be found.
"""
for name, value in namespace.items():
# The value may be referenced by more than one symbol, case in which
# any symbol will be fine. If the program contains symbol aliases that
# change over time, this may capture a symbol that will later point to
# something else.
# TODO(mdan): Prefer the symbol that matches the value type name.
if object_ is value:
return name
# TODO(mdan): Use breadth-first search and avoid visiting modules twice.
if max_depth:
for name, value in namespace.items():
if tf_inspect.ismodule(value):
name_in_module = getqualifiedname(value.__dict__, object_,
max_depth - 1)
if name_in_module is not None:
return '{}.{}'.format(name, name_in_module)
return None
def _get_unbound_function(m):
# TODO(mdan): Figure out why six.get_unbound_function fails in some cases.
# The failure case is for tf.keras.Model.

View File

@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
from functools import wraps
import imp
import six
@ -127,6 +128,24 @@ class InspectUtilsTest(test.TestCase):
self.assertEqual(ns['closed_over_primitive'], closed_over_primitive)
self.assertTrue('local_var' not in ns)
def test_getqualifiedname(self):
foo = object()
qux = imp.new_module('quxmodule')
bar = imp.new_module('barmodule')
baz = object()
bar.baz = baz
ns = {
'foo': foo,
'bar': bar,
'qux': qux,
}
self.assertIsNone(inspect_utils.getqualifiedname(ns, inspect_utils))
self.assertEqual(inspect_utils.getqualifiedname(ns, foo), 'foo')
self.assertEqual(inspect_utils.getqualifiedname(ns, bar), 'bar')
self.assertEqual(inspect_utils.getqualifiedname(ns, baz), 'bar.baz')
def test_getmethodclass(self):
self.assertEqual(