Add utility to get current name scope in tf.contrib.framework.
Change: 153409845
This commit is contained in:
parent
7448b27527
commit
949d457dc9
@ -53,6 +53,7 @@ See the @{$python/contrib.framework} guide.
|
||||
@@get_or_create_global_step
|
||||
@@get_local_variables
|
||||
@@get_model_variables
|
||||
@@get_name_scope
|
||||
@@get_trainable_variables
|
||||
@@get_unique_variable
|
||||
@@get_variables_by_name
|
||||
|
@ -52,3 +52,21 @@ def get_graph_from_inputs(op_input_list, graph=None):
|
||||
"""
|
||||
# pylint: disable=protected-access
|
||||
return ops._get_graph_from_inputs(op_input_list, graph)
|
||||
|
||||
|
||||
def get_name_scope():
|
||||
"""Returns the current name scope of the default graph.
|
||||
|
||||
For example:
|
||||
|
||||
```python
|
||||
with tf.name_scope('scope1'):
|
||||
with tf.name_scope('scope2'):
|
||||
print(tf.get_name_scope())
|
||||
```
|
||||
would print the string `scope1/scope2`.
|
||||
|
||||
Returns:
|
||||
A string represnting the current name scope.
|
||||
"""
|
||||
return ops.get_default_graph().get_name_scope()
|
||||
|
@ -57,6 +57,15 @@ class OpsTest(test.TestCase):
|
||||
with self.assertRaisesRegexp(ValueError, "not from the passed-in graph"):
|
||||
ops_lib.get_graph_from_inputs(values, g1)
|
||||
|
||||
def testGetNameScope(self):
|
||||
with ops.name_scope("scope1"):
|
||||
with ops.name_scope("scope2"):
|
||||
with ops.name_scope("scope3"):
|
||||
self.assertEqual("scope1/scope2/scope3", ops_lib.get_name_scope())
|
||||
self.assertEqual("scope1/scope2", ops_lib.get_name_scope())
|
||||
self.assertEqual("scope1", ops_lib.get_name_scope())
|
||||
self.assertEqual("", ops_lib.get_name_scope())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -2907,6 +2907,23 @@ class Graph(object):
|
||||
self._names_in_use[name] = 1
|
||||
return name
|
||||
|
||||
def get_name_scope(self):
|
||||
"""Returns the current name scope.
|
||||
|
||||
For example:
|
||||
|
||||
```python
|
||||
with tf.name_scope('scope1'):
|
||||
with tf.name_scope('scope2'):
|
||||
print(tf.get_default_graph().get_name_scope())
|
||||
```
|
||||
would print the string `scope1/scope2`.
|
||||
|
||||
Returns:
|
||||
A string representing the current name scope.
|
||||
"""
|
||||
return self._name_stack
|
||||
|
||||
@contextlib.contextmanager
|
||||
def colocate_with(self, op, ignore_existing=False):
|
||||
"""Returns a context manager that specifies an op to colocate with.
|
||||
|
@ -1662,6 +1662,16 @@ class NameScopeTest(test_util.TensorFlowTestCase):
|
||||
self.assertEqual(es, striped)
|
||||
self.assertEqual(ep, ops.prepend_name_scope(striped, name_scope_to_add))
|
||||
|
||||
def testGetNameScope(self):
|
||||
with ops.Graph().as_default() as g:
|
||||
with ops.name_scope("scope1"):
|
||||
with ops.name_scope("scope2"):
|
||||
with ops.name_scope("scope3"):
|
||||
self.assertEqual("scope1/scope2/scope3", g.get_name_scope())
|
||||
self.assertEqual("scope1/scope2", g.get_name_scope())
|
||||
self.assertEqual("scope1", g.get_name_scope())
|
||||
self.assertEqual("", g.get_name_scope())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
||||
|
@ -86,6 +86,10 @@ tf_class {
|
||||
name: "get_collection_ref"
|
||||
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_name_scope"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_operation_by_name"
|
||||
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
|
||||
|
Loading…
Reference in New Issue
Block a user