Add utility to get current name scope in tf.contrib.framework.

Change: 153409845
This commit is contained in:
A. Unique TensorFlower 2017-04-17 16:18:32 -08:00 committed by TensorFlower Gardener
parent 7448b27527
commit 949d457dc9
6 changed files with 59 additions and 0 deletions

View File

@ -53,6 +53,7 @@ See the @{$python/contrib.framework} guide.
@@get_or_create_global_step
@@get_local_variables
@@get_model_variables
@@get_name_scope
@@get_trainable_variables
@@get_unique_variable
@@get_variables_by_name

View File

@ -52,3 +52,21 @@ def get_graph_from_inputs(op_input_list, graph=None):
"""
# pylint: disable=protected-access
return ops._get_graph_from_inputs(op_input_list, graph)
def get_name_scope():
"""Returns the current name scope of the default graph.
For example:
```python
with tf.name_scope('scope1'):
with tf.name_scope('scope2'):
print(tf.get_name_scope())
```
would print the string `scope1/scope2`.
Returns:
A string represnting the current name scope.
"""
return ops.get_default_graph().get_name_scope()

View File

@ -57,6 +57,15 @@ class OpsTest(test.TestCase):
with self.assertRaisesRegexp(ValueError, "not from the passed-in graph"):
ops_lib.get_graph_from_inputs(values, g1)
def testGetNameScope(self):
with ops.name_scope("scope1"):
with ops.name_scope("scope2"):
with ops.name_scope("scope3"):
self.assertEqual("scope1/scope2/scope3", ops_lib.get_name_scope())
self.assertEqual("scope1/scope2", ops_lib.get_name_scope())
self.assertEqual("scope1", ops_lib.get_name_scope())
self.assertEqual("", ops_lib.get_name_scope())
if __name__ == "__main__":
test.main()

View File

@ -2907,6 +2907,23 @@ class Graph(object):
self._names_in_use[name] = 1
return name
def get_name_scope(self):
"""Returns the current name scope.
For example:
```python
with tf.name_scope('scope1'):
with tf.name_scope('scope2'):
print(tf.get_default_graph().get_name_scope())
```
would print the string `scope1/scope2`.
Returns:
A string representing the current name scope.
"""
return self._name_stack
@contextlib.contextmanager
def colocate_with(self, op, ignore_existing=False):
"""Returns a context manager that specifies an op to colocate with.

View File

@ -1662,6 +1662,16 @@ class NameScopeTest(test_util.TensorFlowTestCase):
self.assertEqual(es, striped)
self.assertEqual(ep, ops.prepend_name_scope(striped, name_scope_to_add))
def testGetNameScope(self):
with ops.Graph().as_default() as g:
with ops.name_scope("scope1"):
with ops.name_scope("scope2"):
with ops.name_scope("scope3"):
self.assertEqual("scope1/scope2/scope3", g.get_name_scope())
self.assertEqual("scope1/scope2", g.get_name_scope())
self.assertEqual("scope1", g.get_name_scope())
self.assertEqual("", g.get_name_scope())
if __name__ == "__main__":
googletest.main()

View File

@ -86,6 +86,10 @@ tf_class {
name: "get_collection_ref"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_name_scope"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_operation_by_name"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"